swiftide_integrations/qdrant/
persist.rs

//! This module provides an implementation of the `Persist` trait for the `Qdrant` struct.
//! It includes methods for setting up the storage, storing a single node, and storing a batch of
//! nodes. This integration allows the Swiftide project to use Qdrant as a storage backend.
4
5use std::collections::HashSet;
6use swiftide_core::{
7    indexing::{EmbeddedField, IndexingStream, Persist, TextNode},
8    prelude::*,
9};
10
11use qdrant_client::qdrant::UpsertPointsBuilder;
12
13use super::{NodeWithVectors, Qdrant};
14
15#[async_trait]
16impl Persist for Qdrant {
17    type Input = String;
18    type Output = String;
19
20    /// Returns the batch size for the Qdrant storage.
21    ///
22    /// # Returns
23    ///
24    /// An `Option<usize>` representing the batch size if set, otherwise `None`.
25    fn batch_size(&self) -> Option<usize> {
26        self.batch_size
27    }
28
29    /// Sets up the Qdrant storage by creating the necessary index if it does not exist.
30    ///
31    /// # Returns
32    ///
33    /// A `Result<()>` which is `Ok` if the setup is successful, otherwise an error.
34    ///
35    /// # Errors
36    ///
37    /// This function will return an error if the index creation fails.
38    #[tracing::instrument(skip_all, err)]
39    async fn setup(&self) -> Result<()> {
40        tracing::debug!("Setting up Qdrant storage");
41        self.create_index_if_not_exists().await
42    }
43
44    /// Stores a single indexing node in the Qdrant storage.
45    ///
46    /// WARN: If running debug builds, the store is blocking and will impact performance
47    ///
48    /// # Parameters
49    ///
50    /// - `node`: The `TextNode` to be stored.
51    ///
52    /// # Returns
53    ///
54    /// A `Result<()>` which is `Ok` if the storage is successful, otherwise an error.
55    ///
56    /// # Errors
57    ///
58    /// This function will return an error if the node conversion or storage operation fails.
59    #[tracing::instrument(skip_all, err, name = "storage.qdrant.store")]
60    async fn store(&self, node: TextNode) -> Result<TextNode> {
61        let node_with_vectors = NodeWithVectors::new(&node, self.vector_fields());
62        let point = node_with_vectors.try_into()?;
63
64        tracing::debug!("Storing node");
65
66        self.client
67            .upsert_points(
68                UpsertPointsBuilder::new(self.collection_name.clone(), vec![point])
69                    .wait(cfg!(debug_assertions)),
70            )
71            .await?;
72        Ok(node)
73    }
74
75    /// Stores a batch of indexing nodes in the Qdrant storage.
76    ///
77    /// # Parameters
78    ///
79    /// - `nodes`: A vector of `TextNode` to be stored.
80    ///
81    /// # Returns
82    ///
83    /// A `Result<()>` which is `Ok` if the storage is successful, otherwise an error.
84    ///
85    /// # Errors
86    ///
87    /// This function will return an error if any node conversion or storage operation fails.
88    #[tracing::instrument(skip_all, name = "storage.qdrant.batch_store")]
89    async fn batch_store(&self, nodes: Vec<TextNode>) -> IndexingStream<String> {
90        let points = nodes
91            .iter()
92            .map(|node| NodeWithVectors::new(node, self.vector_fields()))
93            .map(NodeWithVectors::try_into)
94            .collect::<Result<Vec<_>>>();
95
96        let Ok(points) = points else {
97            return vec![Err(points.unwrap_err())].into();
98        };
99
100        tracing::debug!("Storing batch of {} nodes", points.len());
101
102        let result = self
103            .client
104            .upsert_points(
105                UpsertPointsBuilder::new(self.collection_name.clone(), points)
106                    .wait(cfg!(debug_assertions)),
107            )
108            .await;
109
110        if result.is_ok() {
111            IndexingStream::iter(nodes.into_iter().map(Ok))
112        } else {
113            vec![Err(result.unwrap_err().into())].into()
114        }
115    }
116}
117
118impl Qdrant {
119    fn vector_fields(&self) -> HashSet<&EmbeddedField> {
120        self.vectors.keys().collect::<HashSet<_>>()
121    }
122}