Skip to main content

rig_qdrant/
lib.rs

1//! Qdrant vector store integration for Rig.
2//!
3//! This crate provides [`QdrantVectorStore`], a Rig vector store index backed
4//! by Qdrant collections. It supports dense vector search and Qdrant filter
5//! expressions through [`QdrantFilter`].
6//!
7//! The root `rig` facade re-exports this crate as `rig::qdrant` when the
8//! `qdrant` feature is enabled.
9
10mod filter;
11
12pub use filter::QdrantFilter;
13use qdrant_client::{
14    Payload, Qdrant,
15    qdrant::{
16        Filter, PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder,
17        point_id::PointIdOptions,
18    },
19};
20use rig_core::{
21    Embed, OneOrMany,
22    embeddings::{Embedding, EmbeddingModel},
23    vector_store::{
24        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
25    },
26};
27use serde::{Deserialize, Serialize};
28use uuid::Uuid;
29
30/// Represents a vector store implementation using Qdrant - <https://qdrant.tech/> as the backend.
31pub struct QdrantVectorStore<M: EmbeddingModel> {
32    /// Model used to generate embeddings for the vector store
33    model: M,
34    /// Client instance for Qdrant server communication
35    client: Qdrant,
36    /// Default search parameters
37    query_params: QueryPoints,
38}
39
40impl<M> QdrantVectorStore<M>
41where
42    M: EmbeddingModel,
43{
44    /// Creates a new instance of `QdrantVectorStore`.
45    ///
46    /// # Arguments
47    /// * `client` - Qdrant client instance
48    /// * `model` - Embedding model instance
49    /// * `query_params` - Search parameters for vector queries
50    ///   Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points>
51    pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
52        Self {
53            client,
54            model,
55            query_params,
56        }
57    }
58
59    pub fn client(&self) -> &Qdrant {
60        &self.client
61    }
62
63    /// Embed query based on `QdrantVectorStore` model and modify the vector in the required format.
64    async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
65        let embedding = self.model.embed_text(query).await?;
66        Ok(embedding.vec.iter().map(|&x| x as f32).collect())
67    }
68
69    /// Fill in query parameters with the given query and limit.
70    fn prepare_query_params(
71        &self,
72        query: Option<Query>,
73        limit: usize,
74        threshold: Option<f64>,
75        filter: Option<Filter>,
76    ) -> QueryPoints {
77        let mut params = self.query_params.clone();
78        params.query = query;
79        params.limit = Some(limit as u64);
80        params.score_threshold = threshold.map(|x| x as f32);
81        params.filter = filter;
82        params
83    }
84}
85
86impl<Model> InsertDocuments for QdrantVectorStore<Model>
87where
88    Model: EmbeddingModel + Send + Sync,
89{
90    async fn insert_documents<Doc: Serialize + Embed + Send>(
91        &self,
92        documents: Vec<(Doc, OneOrMany<Embedding>)>,
93    ) -> Result<(), VectorStoreError> {
94        let collection_name = self.query_params.collection_name.clone();
95
96        for (document, embeddings) in documents {
97            let json_document = serde_json::to_value(&document)?;
98            let doc_as_payload = Payload::try_from(json_document)
99                .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
100
101            let embeddings_as_point_structs = embeddings
102                .into_iter()
103                .map(|embedding| {
104                    let embedding_as_f32: Vec<f32> =
105                        embedding.vec.into_iter().map(|x| x as f32).collect();
106                    PointStruct::new(
107                        Uuid::new_v4().to_string(),
108                        embedding_as_f32,
109                        doc_as_payload.clone(),
110                    )
111                })
112                .collect::<Vec<PointStruct>>();
113
114            let request =
115                UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs).wait(true);
116            self.client.upsert_points(request).await.map_err(|err| {
117                VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into())
118            })?;
119        }
120
121        Ok(())
122    }
123}
124
125/// Converts a `PointId` to its string representation.
126fn stringify_id(id: PointId) -> Result<String, VectorStoreError> {
127    match id.point_id_options {
128        Some(PointIdOptions::Num(num)) => Ok(num.to_string()),
129        Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()),
130        None => Err(VectorStoreError::DatastoreError(
131            "Invalid point ID format".into(),
132        )),
133    }
134}
135
136impl<M> VectorStoreIndex for QdrantVectorStore<M>
137where
138    M: EmbeddingModel + std::marker::Sync + Send,
139{
140    type Filter = QdrantFilter;
141
142    /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
143    /// Returns a vector of tuples containing the score, ID, and payload of the nearest neighbors.
144    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
145        &self,
146        req: VectorSearchRequest<Self::Filter>,
147    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
148        let query = match self.query_params.query {
149            Some(ref q) => Some(q.clone()),
150            None => Some(Query::new_nearest(
151                self.generate_query_vector(req.query()).await?,
152            )),
153        };
154
155        let filter = req
156            .filter()
157            .as_ref()
158            .cloned()
159            .map(QdrantFilter::interpret)
160            .transpose()?
161            .flatten();
162
163        let params =
164            self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
165
166        let result = self
167            .client
168            .query(params)
169            .await
170            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
171
172        result
173            .result
174            .into_iter()
175            .map(|item| {
176                let id =
177                    stringify_id(item.id.ok_or_else(|| {
178                        VectorStoreError::DatastoreError("Missing point ID".into())
179                    })?)?;
180                let score = item.score as f64;
181                let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?;
182                Ok((score, id, payload))
183            })
184            .collect()
185    }
186
187    /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
188    /// Returns a vector of tuples containing the score and ID of the nearest neighbors.
189    async fn top_n_ids(
190        &self,
191        req: VectorSearchRequest<Self::Filter>,
192    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
193        let query = match self.query_params.query {
194            Some(ref q) => Some(q.clone()),
195            None => Some(Query::new_nearest(
196                self.generate_query_vector(req.query()).await?,
197            )),
198        };
199
200        let filter = req
201            .filter()
202            .as_ref()
203            .cloned()
204            .map(QdrantFilter::interpret)
205            .transpose()?
206            .flatten();
207
208        let params =
209            self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
210
211        let points = self
212            .client
213            .query(params)
214            .await
215            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
216            .result;
217
218        points
219            .into_iter()
220            .map(|point| {
221                let id =
222                    stringify_id(point.id.ok_or_else(|| {
223                        VectorStoreError::DatastoreError("Missing point ID".into())
224                    })?)?;
225                Ok((point.score as f64, id))
226            })
227            .collect()
228    }
229}