rig_qdrant/
lib.rs

1use qdrant_client::{
2    Payload, Qdrant,
3    qdrant::{
4        PointId, PointStruct, Query, QueryPoints, UpsertPointsBuilder, point_id::PointIdOptions,
5    },
6};
7use rig::{
8    Embed, OneOrMany,
9    embeddings::{Embedding, EmbeddingModel},
10    vector_store::{
11        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
12    },
13};
14use serde::{Deserialize, Serialize};
15use uuid::Uuid;
16
17/// Represents a vector store implementation using Qdrant - <https://qdrant.tech/> as the backend.
18pub struct QdrantVectorStore<M: EmbeddingModel> {
19    /// Model used to generate embeddings for the vector store
20    model: M,
21    /// Client instance for Qdrant server communication
22    client: Qdrant,
23    /// Default search parameters
24    query_params: QueryPoints,
25}
26
27impl<M> QdrantVectorStore<M>
28where
29    M: EmbeddingModel,
30{
31    /// Creates a new instance of `QdrantVectorStore`.
32    ///
33    /// # Arguments
34    /// * `client` - Qdrant client instance
35    /// * `model` - Embedding model instance
36    /// * `query_params` - Search parameters for vector queries
37    ///   Reference: <https://api.qdrant.tech/v-1-12-x/api-reference/search/query-points>
38    pub fn new(client: Qdrant, model: M, query_params: QueryPoints) -> Self {
39        Self {
40            client,
41            model,
42            query_params,
43        }
44    }
45
46    pub fn client(&self) -> &Qdrant {
47        &self.client
48    }
49
50    /// Embed query based on `QdrantVectorStore` model and modify the vector in the required format.
51    async fn generate_query_vector(&self, query: &str) -> Result<Vec<f32>, VectorStoreError> {
52        let embedding = self.model.embed_text(query).await?;
53        Ok(embedding.vec.iter().map(|&x| x as f32).collect())
54    }
55
56    /// Fill in query parameters with the given query and limit.
57    fn prepare_query_params(
58        &self,
59        query: Option<Query>,
60        limit: usize,
61        threshold: Option<f64>,
62    ) -> QueryPoints {
63        let mut params = self.query_params.clone();
64        params.query = query;
65        params.limit = Some(limit as u64);
66        params.score_threshold = threshold.map(|x| x as f32);
67        params
68    }
69}
70
71impl<Model> InsertDocuments for QdrantVectorStore<Model>
72where
73    Model: EmbeddingModel + Send + Sync,
74{
75    async fn insert_documents<Doc: Serialize + Embed + Send>(
76        &self,
77        documents: Vec<(Doc, OneOrMany<Embedding>)>,
78    ) -> Result<(), VectorStoreError> {
79        let collection_name = self.query_params.collection_name.clone();
80
81        for (document, embeddings) in documents {
82            let json_document = serde_json::to_value(&document).unwrap();
83            let doc_as_payload = Payload::try_from(json_document).unwrap();
84
85            let embeddings_as_point_structs = embeddings
86                .into_iter()
87                .map(|embedding| {
88                    let embedding_as_f32: Vec<f32> =
89                        embedding.vec.into_iter().map(|x| x as f32).collect();
90                    PointStruct::new(
91                        Uuid::new_v4().to_string(),
92                        embedding_as_f32,
93                        doc_as_payload.clone(),
94                    )
95                })
96                .collect::<Vec<PointStruct>>();
97
98            let request = UpsertPointsBuilder::new(&collection_name, embeddings_as_point_structs);
99            self.client.upsert_points(request).await.map_err(|err| {
100                VectorStoreError::DatastoreError(format!("Error while upserting: {err}").into())
101            })?;
102        }
103
104        Ok(())
105    }
106}
107
108/// Converts a `PointId` to its string representation.
109fn stringify_id(id: PointId) -> Result<String, VectorStoreError> {
110    match id.point_id_options {
111        Some(PointIdOptions::Num(num)) => Ok(num.to_string()),
112        Some(PointIdOptions::Uuid(uuid)) => Ok(uuid.to_string()),
113        None => Err(VectorStoreError::DatastoreError(
114            "Invalid point ID format".into(),
115        )),
116    }
117}
118
119impl<M> VectorStoreIndex for QdrantVectorStore<M>
120where
121    M: EmbeddingModel + std::marker::Sync + Send,
122{
123    /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
124    /// Returns a vector of tuples containing the score, ID, and payload of the nearest neighbors.
125    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
126        &self,
127        req: VectorSearchRequest,
128    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
129        let query = match self.query_params.query {
130            Some(ref q) => Some(q.clone()),
131            None => Some(Query::new_nearest(
132                self.generate_query_vector(req.query()).await?,
133            )),
134        };
135
136        let params = self.prepare_query_params(query, req.samples() as usize, req.threshold());
137        let result = self
138            .client
139            .query(params)
140            .await
141            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
142
143        result
144            .result
145            .into_iter()
146            .map(|item| {
147                let id =
148                    stringify_id(item.id.ok_or_else(|| {
149                        VectorStoreError::DatastoreError("Missing point ID".into())
150                    })?)?;
151                let score = item.score as f64;
152                let payload = serde_json::from_value(serde_json::to_value(item.payload)?)?;
153                Ok((score, id, payload))
154            })
155            .collect()
156    }
157
158    /// Search for the top `n` nearest neighbors to the given query within the Qdrant vector store.
159    /// Returns a vector of tuples containing the score and ID of the nearest neighbors.
160    async fn top_n_ids(
161        &self,
162        req: VectorSearchRequest,
163    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
164        let query = match self.query_params.query {
165            Some(ref q) => Some(q.clone()),
166            None => Some(Query::new_nearest(
167                self.generate_query_vector(req.query()).await?,
168            )),
169        };
170
171        let params = self.prepare_query_params(query, req.samples() as usize, req.threshold());
172        let points = self
173            .client
174            .query(params)
175            .await
176            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?
177            .result;
178
179        points
180            .into_iter()
181            .map(|point| {
182                let id =
183                    stringify_id(point.id.ok_or_else(|| {
184                        VectorStoreError::DatastoreError("Missing point ID".into())
185                    })?)?;
186                Ok((point.score as f64, id))
187            })
188            .collect()
189    }
190}