rag_toolchain/stores/
postgres_vector_store.rs

1use crate::clients::AsyncEmbeddingClient;
2use crate::common::{Chunk, Embedding, EmbeddingModel};
3use crate::retrievers::{DistanceFunction, PostgresVectorRetriever};
4use crate::stores::traits::EmbeddingStore;
5use sqlx::postgres::{PgPoolOptions, PgQueryResult};
6use sqlx::{postgres::PgArguments, Pool, Postgres};
7use std::env::{self, VarError};
8use thiserror::Error;
9
10use dotenv::dotenv;
11
12/// # [`PostgresVectorStore`]
13///
14/// This is the implementation of [`EmbeddingStore`] for a Postgres database with the
15/// pgvector extension enabled. This store takes a table name and an embedding model.
16/// If a table already exists with the same name and does not have the expected columns
17/// any calls to [`PostgresVectorStore::store`] or [`PostgresVectorStore::store_batch`]
18/// will fail.
19///
20/// # Required Environment Variables
21///
22/// * POSTGRES_USERNAME: The username to connect to the database with
23/// * POSTGRES_PASSWORD: The password to connect to the database with
24/// * POSTGRES_HOST: The host to connect to the database with
25/// * POSTGRES_DATABASE: The database to connect with
26///
27/// # Output table format
28/// Columns: | id (int) | content (text) | embedding (vector) | metadata (jsonb) |
29///
30/// # Examples
31/// ```
32/// use rag_toolchain::stores::*;
33/// use rag_toolchain::common::*;
34///
35/// async fn store(embeddings: Vec<Embedding>) {
36///     let embedding_model: OpenAIEmbeddingModel = OpenAIEmbeddingModel::TextEmbedding3Small;
37///     let table_name: &str = "table_name";
38///     let store: PostgresVectorStore = PostgresVectorStore::try_new(table_name, embedding_model)
39///         .await.unwrap();
40///     store.store_batch(embeddings).await.unwrap();
41/// }
42/// ```
43#[derive(Debug, Clone)]
44pub struct PostgresVectorStore {
45    /// We hold a connection pool to the database
46    pool: Pool<Postgres>,
47    /// The name of the table we are operating on
48    table_name: String,
49}
50
51impl PostgresVectorStore {
52    /// # [`PostgresVectorStore::try_new`]
53    ///
54    /// This constructor is used to create a new PostgresVectorStore. It will read the required
55    /// environment variables in. Try and connect to your postgres database and then create a table
56    /// with the given name and the expected columns. If the table already exists with the same name
57    /// it will not be re-created.
58    ///
59    /// # Arguments
60    /// * `table_name`: &[`str`] - The name of the table to store the embeddings in.
61    /// * `embedding_model`: impl [`EmbeddingModel`] - The embedding model to use to store the embeddings
62    ///
63    /// # Errors
64    /// * [`PostgresVectorError::EnvVarError`] if the required environment variables are not set.
65    /// * [`PostgresVectorError::ConnectionError`] if the connection to the database could not be established.
66    /// * [`PostgresVectorError::TableCreationError`] if the table could not be created.
67    ///
68    /// # Returns
69    /// * [`PostgresVectorStore`] if the connection and table creation is successful
70    pub async fn try_new(
71        table_name: &str,
72        embedding_model: impl EmbeddingModel,
73    ) -> Result<Self, PostgresVectorStoreError> {
74        dotenv().ok();
75        let username: String = env::var("POSTGRES_USER")?;
76        let password: String = env::var("POSTGRES_PASSWORD")?;
77        let host: String = env::var("POSTGRES_HOST")?;
78        let db_name: String = env::var("POSTGRES_DATABASE")?;
79
80        let embedding_diminsions = embedding_model.metadata().dimensions;
81        let connection_string =
82            format!("postgres://{}:{}@{}/{}", username, password, host, db_name);
83
84        // Connect to the database
85        let pool = PostgresVectorStore::connect(&connection_string)
86            .await
87            .map_err(PostgresVectorStoreError::ConnectionError)?;
88
89        // Create the table
90        PostgresVectorStore::create_table(&pool, table_name, embedding_diminsions)
91            .await
92            .map_err(PostgresVectorStoreError::TableCreationError)?;
93
94        Ok(PostgresVectorStore {
95            pool,
96            table_name: table_name.into(),
97        })
98    }
99
100    /// # [`PostgresVectorStore::try_new_with_pool`]
101    ///
102    /// This is an alternative constructor that allows you to pass in a connection pool.
103    /// This was added as it may be the case people want to establish one connection pool
104    /// and then shared it across multiple [`PostgresVectorStore`]s managing different tables.
105    ///
106    /// # Arguments
107    /// * `pool`: [`sqlx::Pool<Postgres>`] - a pre established connection pool.
108    /// * `table_name`: &[`str`] - The name of the table to store the embeddings in.
109    /// * `embedding_model`: impl[`EmbeddingModel`] - The embedding model used for the genrated embeddings.
110    ///
111    /// # Errors
112    /// * [`PostgresVectorError::TableCreationError`] if the table could not be created
113    ///
114    /// # Returns
115    /// * [`PostgresVectorStore`] if the table creation is successful.
116    pub async fn try_new_with_pool(
117        pool: Pool<Postgres>,
118        table_name: &str,
119        embedding_model: impl EmbeddingModel,
120    ) -> Result<Self, PostgresVectorStoreError> {
121        let embedding_diminsions = embedding_model.metadata().dimensions;
122
123        // Create the table
124        PostgresVectorStore::create_table(&pool, table_name, embedding_diminsions)
125            .await
126            .map_err(PostgresVectorStoreError::TableCreationError)?;
127
128        Ok(PostgresVectorStore {
129            pool,
130            table_name: table_name.into(),
131        })
132    }
133
134    /// # [`PostgresVectorStore::get_pool`]
135    ///
136    /// Getter for the internal connection pool.
137    /// This is useful if you want to do any further operations on the database
138    /// such as enabling an index on the table.
139    ///
140    /// # Returns
141    /// * [`Pool`] - The connection pool
142    pub fn get_pool(&self) -> Pool<Postgres> {
143        self.pool.clone()
144    }
145
146    /// # [`PostgresVectorStore::as_retriever`]
147    ///
148    /// This function allows us to convert the store into a retriever.
149    /// Note that the returned retriever is bound to the same table as the store.
150    ///
151    /// # Arguments
152    /// * `embedding_client`: [`AsyncEmbeddingClient`] - The client we use to embed
153    ///     `                  income text before the similarity search.
154    /// * `distance_function`: [`DistanceFunction`] - The distance function to use to
155    ///                        compare the embeddings
156    ///
157    /// # Returns
158    /// [`PostgresVectorRetriever`] - The retriever that can be used to search for similar text.
159    pub fn as_retriever<T: AsyncEmbeddingClient>(
160        &self,
161        embedding_client: T,
162        distance_function: DistanceFunction,
163    ) -> PostgresVectorRetriever<T> {
164        PostgresVectorRetriever::new(
165            self.pool.clone(),
166            self.table_name.clone(),
167            embedding_client,
168            distance_function,
169        )
170    }
171
172    /// # [`PostgresVectorStore::connect`]
173    /// Allows us to establish a connection to a database and store the connection pool
174    ///
175    /// # Arguments
176    /// * `connection_string`: &[`str`] - The connection string to use to connect to the database
177    ///
178    /// # Errors
179    /// * [`sqlx::Error`] if the connection could not be established.
180    ///
181    /// # Returns
182    /// * [`Pool`] which can be used to query the database
183    async fn connect(connection_string: &str) -> Result<Pool<Postgres>, sqlx::Error> {
184        let pool: Pool<Postgres> = PgPoolOptions::new()
185            .max_connections(5)
186            .connect(connection_string)
187            .await?;
188        Ok(pool)
189    }
190
191    /// # [`PostgresVectorStore::create_table`]
192    /// We call the create table automatically when the struct is created
193    ///
194    /// # Arguments
195    /// * `pool`: [`sqlx::Pool<Postgres>`] - The connection pool to use to create the table
196    /// * `table_name`: &[`str`] - The name of the table to create
197    /// * `vector_dimension`: [`usize`] - The dimension of the vector to store
198    ///
199    /// # Errors
200    /// * [`sqlx::Error`] if the table could not be created.
201    ///
202    /// # Returns
203    /// * [`PgQueryResult`] which can be used to check if the table was created successfully
204    async fn create_table(
205        pool: &Pool<Postgres>,
206        table_name: &str,
207        vector_dimension: usize,
208    ) -> Result<PgQueryResult, sqlx::Error> {
209        let statement = format!(
210            "CREATE TABLE IF NOT EXISTS {} (
211                id SERIAL PRIMARY KEY,
212                content TEXT NOT NULL,
213                embedding VECTOR({}) NOT NULL,
214                metadata JSONB
215            )",
216            table_name, vector_dimension
217        );
218        sqlx::query(&statement).execute(pool).await
219    }
220
221    /// # [`PostgresVectorStore::insert_row_sql`]
222    /// Helper function to generate the sql query for inserting a new row
223    ///
224    /// # Arguments
225    /// * `table_name`: &[`str`] - The name of the table to insert into
226    ///
227    /// # Returns
228    /// * [`String`] - The sql query
229    fn insert_row_sql(table_name: &str) -> String {
230        format!(
231            "INSERT INTO {} (content, embedding, metadata) VALUES ($1, $2, $3)",
232            table_name
233        )
234    }
235
236    /// # [`PostgresVectorStore::bind_to_query`]
237    /// Helper function to bind an [`Embedding`] to an [`sqlx::query::Query`]
238    /// the retuned query can then have [`sqlx::query::Query::execute`] called on it to
239    /// insert the row.
240    fn bind_to_query(
241        query: &str,
242        embedding: Embedding,
243    ) -> sqlx::query::Query<'_, Postgres, PgArguments> {
244        let chunk: &Chunk = embedding.chunk();
245        let text: String = chunk.content().to_string();
246        let metadata = chunk.metadata().clone();
247        let vector: Vec<f32> = embedding.vector();
248        sqlx::query(query).bind(text).bind(vector).bind(metadata)
249    }
250}
251
252impl EmbeddingStore for PostgresVectorStore {
253    type ErrorType = PostgresVectorStoreError;
254    /// # [`PostgresVectorStore::store`]
255    /// This is done as a single insert statement.
256    ///
257    /// # Arguments
258    /// * `embedding`: [`Embedding`] - to insert
259    ///
260    /// # Errors
261    /// * [`PostgresVectorError::InsertError`] if the insert fails
262    ///
263    /// # Returns
264    /// * [`()`] if the insert succeeds
265    async fn store(&self, embedding: Embedding) -> Result<(), PostgresVectorStoreError> {
266        let query: String = PostgresVectorStore::insert_row_sql(&self.table_name);
267        Self::bind_to_query(&query, embedding)
268            .execute(&self.pool)
269            .await
270            .map_err(PostgresVectorStoreError::InsertError)?;
271        Ok(())
272    }
273
274    /// # [`PostgresVectorStore::store_batch`]
275    /// This is done as a single transaction with multiple insert statements.
276    ///
277    /// # Arguments
278    /// * `embeddings`: [`Vec<Embedding>`] - A vector of embeddings to insert
279    ///
280    /// # Errors
281    /// * [`PostgresVectorError::TransactionError`] if the transaction fails
282    ///
283    /// # Returns
284    /// * [`()`] if the transaction succeeds
285    async fn store_batch(
286        &self,
287        embeddings: Vec<Embedding>,
288    ) -> Result<(), PostgresVectorStoreError> {
289        let query: String = PostgresVectorStore::insert_row_sql(&self.table_name);
290        let mut transaction = self
291            .pool
292            .begin()
293            .await
294            .map_err(PostgresVectorStoreError::TransactionError)?;
295
296        for embedding in embeddings {
297            Self::bind_to_query(&query, embedding)
298                .execute(&mut *transaction)
299                .await
300                .map_err(PostgresVectorStoreError::InsertError)?;
301        }
302
303        transaction
304            .commit()
305            .await
306            .map_err(PostgresVectorStoreError::TransactionError)?;
307        Ok(())
308    }
309}
310
311/// # [`PostgresVectorError`]
312/// This Error enum wraps all the errors that can occur when using
313/// the PgVector struct with contextual meaning.
314#[derive(Error, Debug)]
315pub enum PostgresVectorStoreError {
316    /// Error when an environment variable is not set
317    #[error("Environment Variable Error: {0}")]
318    EnvVarError(VarError),
319    /// Error when the connection to the database could not be established
320    #[error("Connection Error: {0}")]
321    ConnectionError(sqlx::Error),
322    /// Error when the table could not be created
323    #[error("Table Creation Error: {0}")]
324    TableCreationError(sqlx::Error),
325    /// Error when calling [`PostgresVectorStore::store()`] fails
326    #[error("Upsert Error: {0}")]
327    InsertError(sqlx::Error),
328    /// Error when calling [`PostgresVectorStore::store_batch()`] fails
329    #[error("Transaction Error: {0}")]
330    TransactionError(sqlx::Error),
331}
332
333impl From<VarError> for PostgresVectorStoreError {
334    fn from(error: VarError) -> Self {
335        PostgresVectorStoreError::EnvVarError(error)
336    }
337}
338
#[cfg(all(test, feature = "pg_vector"))]
mod tests {
    use super::*;
    use crate::common::OpenAIEmbeddingModel::TextEmbeddingAda002;

    /// Walks through the configuration one variable at a time: while any required
    /// variable is still missing the constructor must report an `EnvVarError`; once all
    /// four are set the failure is expected to move on to a `ConnectionError`.
    #[tokio::test]
    async fn test_throws_correct_errors() {
        // Nothing configured yet.
        let outcome = PostgresVectorStore::try_new("test", TextEmbeddingAda002)
            .await
            .unwrap_err();
        assert!(matches!(outcome, PostgresVectorStoreError::EnvVarError(_)));

        // Each of these leaves at least one later variable unset.
        let incomplete_steps = [
            ("POSTGRES_USER", "postgres"),
            ("POSTGRES_PASSWORD", "postgres"),
            ("POSTGRES_HOST", "localhost"),
        ];
        for (key, value) in incomplete_steps {
            std::env::set_var(key, value);
            let outcome = PostgresVectorStore::try_new("test", TextEmbeddingAda002)
                .await
                .unwrap_err();
            assert!(matches!(outcome, PostgresVectorStoreError::EnvVarError(_)));
        }

        // The final variable completes the configuration.
        std::env::set_var("POSTGRES_DATABASE", "postgres");
        let outcome = PostgresVectorStore::try_new("test", TextEmbeddingAda002)
            .await
            .unwrap_err();
        assert!(matches!(
            outcome,
            PostgresVectorStoreError::ConnectionError(_)
        ));
    }
}
378}