rig_qdrant/
lib.rs

1mod filter;
2
3use filter::*;
4use qdrant_client::{
5    Payload, Qdrant,
6    qdrant::{
7        Filter, PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder,
8        point_id::PointIdOptions,
9    },
10};
11use rig::{
12    Embed, OneOrMany,
13    embeddings::{Embedding, EmbeddingModel},
14    vector_store::{
15        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
16    },
17};
18use serde::{Deserialize, Serialize};
19use uuid::Uuid;
20
21/// Represents a vector store implementation using Qdrant - <https://qdrant.tech/> as the backend.
22pub struct QdrantVectorStore<M: EmbeddingModel> {
23    /// Model used to generate embeddings for the vector store
24    model: M,
25    /// Client instance for Qdrant server communication
26    client: Qdrant,
27    /// Default search parameters
28    query_params: QueryPoints,
29}
30
31impl<M> QdrantVectorStore<M>
32where
33    M: EmbeddingModel,
34{
35    /// Creates a new instance of `QdrantVectorStore`.
36    ///
37    /// # Arguments
38    /// * `client` - Qdrant client instance
39    /// * `model` - Embedding model instance
40    /// * `query_params` - Search parameters for vector queries
41    ///   Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points>
42    pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
43        Self {
44            client,
45            model,
46            query_params,
47        }
48    }
49
50    pub fn client(&self) -> &Qdrant {
51        &self.client
52    }
53
54    /// Embed query based on `QdrantVectorStore` model and modify the vector in the required format.
55    async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
56        let embedding = self.model.embed_text(query).await?;
57        Ok(embedding.vec.iter().map(|&x| x as f32).collect())
58    }
59
60    /// Fill in query parameters with the given query and limit.
61    fn prepare_query_params(
62        &self,
63        query: Option<Query>,
64        limit: usize,
65        threshold: Option<f64>,
66        filter: Option<Filter>,
67    ) -> QueryPoints {
68        let mut params = self.query_params.clone();
69        params.query = query;
70        params.limit = Some(limit as u64);
71        params.score_threshold = threshold.map(|x| x as f32);
72        params.filter = filter;
73        params
74    }
75}
76
77impl<Model> InsertDocuments for QdrantVectorStore<Model>
78where
79    Model: EmbeddingModel + Send + Sync,
80{
81    async fn insert_documents<Doc: Serialize + Embed + Send>(
82        &self,
83        documents: Vec<(Doc, OneOrMany<Embedding>)>,
84    ) -> Result<(), VectorStoreError> {
85        let collection_name = self.query_params.collection_name.clone();
86
87        for (document, embeddings) in documents {
88            let json_document = serde_json::to_value(&document).unwrap();
89            let doc_as_payload = Payload::try_from(json_document).unwrap();
90
91            let embeddings_as_point_structs = embeddings
92                .into_iter()
93                .map(|embedding| {
94                    let embedding_as_f32: Vec<f32> =
95                        embedding.vec.into_iter().map(|x| x as f32).collect();
96                    PointStruct::new(
97                        Uuid::new_v4().to_string(),
98                        embedding_as_f32,
99                        doc_as_payload.clone(),
100                    )
101                })
102                .collect::<Vec<PointStruct>>();
103
104            let request = UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs);
105            self.client.upsert_points(request).await.map_err(|err| {
106                VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into())
107            })?;
108        }
109
110        Ok(())
111    }
112}
113
114/// Converts a `PointId` to its string representation.
115fn stringify_id(id: PointId) -> Result<String, VectorStoreError> {
116    match id.point_id_options {
117        Some(PointIdOptions::Num(num)) => Ok(num.to_string()),
118        Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()),
119        None => Err(VectorStoreError::DatastoreError(
120            "Invalid point ID format".into(),
121        )),
122    }
123}
124
125impl<M> VectorStoreIndex for QdrantVectorStore<M>
126where
127    M: EmbeddingModel + std::marker::Sync + Send,
128{
129    type Filter = QdrantFilter;
130
131    /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
132    /// Returns a vector of tuples containing the score, ID, and payload of the nearest neighbors.
133    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
134        &self,
135        req: VectorSearchRequest<Self::Filter>,
136    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
137        let query = match self.query_params.query {
138            Some(ref q) => Some(q.clone()),
139            None => Some(Query::new_nearest(
140                self.generate_query_vector(req.query()).await?,
141            )),
142        };
143
144        let filter = req
145            .filter()
146            .as_ref()
147            .cloned()
148            .map(QdrantFilter::interpret)
149            .transpose()?
150            .flatten();
151
152        let params =
153            self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
154
155        let result = self
156            .client
157            .query(params)
158            .await
159            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
160
161        result
162            .result
163            .into_iter()
164            .map(|item| {
165                let id =
166                    stringify_id(item.id.ok_or_else(|| {
167                        VectorStoreError::DatastoreError("Missing point ID".into())
168                    })?)?;
169                let score = item.score as f64;
170                let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?;
171                Ok((score, id, payload))
172            })
173            .collect()
174    }
175
176    /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
177    /// Returns a vector of tuples containing the score and ID of the nearest neighbors.
178    async fn top_n_ids(
179        &self,
180        req: VectorSearchRequest<Self::Filter>,
181    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
182        let query = match self.query_params.query {
183            Some(ref q) => Some(q.clone()),
184            None => Some(Query::new_nearest(
185                self.generate_query_vector(req.query()).await?,
186            )),
187        };
188
189        let filter = req
190            .filter()
191            .as_ref()
192            .cloned()
193            .map(QdrantFilter::interpret)
194            .transpose()?
195            .flatten();
196
197        let params =
198            self.prepare_query_params(query, req.samples() as usize, req.threshold(), filter);
199
200        let points = self
201            .client
202            .query(params)
203            .await
204            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
205            .result;
206
207        points
208            .into_iter()
209            .map(|point| {
210                let id =
211                    stringify_id(point.id.ok_or_else(|| {
212                        VectorStoreError::DatastoreError("Missing point ID".into())
213                    })?)?;
214                Ok((point.score as f64, id))
215            })
216            .collect()
217    }
218}