rag_toolchain/stores/postgres_vector_store.rs
1use crate::clients::AsyncEmbeddingClient;
2use crate::common::{Chunk, Embedding, EmbeddingModel};
3use crate::retrievers::{DistanceFunction, PostgresVectorRetriever};
4use crate::stores::traits::EmbeddingStore;
5use sqlx::postgres::{PgPoolOptions, PgQueryResult};
6use sqlx::{postgres::PgArguments, Pool, Postgres};
7use std::env::{self, VarError};
8use thiserror::Error;
9
10use dotenv::dotenv;
11
12/// # [`PostgresVectorStore`]
13///
14/// This is the implementation of [`EmbeddingStore`] for a Postgres database with the
15/// pgvector extension enabled. This store takes a table name and an embedding model.
16/// If a table already exists with the same name and does not have the expected columns
17/// any calls to [`PostgresVectorStore::store`] or [`PostgresVectorStore::store_batch`]
18/// will fail.
19///
20/// # Required Environment Variables
21///
22/// * POSTGRES_USERNAME: The username to connect to the database with
23/// * POSTGRES_PASSWORD: The password to connect to the database with
24/// * POSTGRES_HOST: The host to connect to the database with
25/// * POSTGRES_DATABASE: The database to connect with
26///
27/// # Output table format
28/// Columns: | id (int) | content (text) | embedding (vector) | metadata (jsonb) |
29///
30/// # Examples
31/// ```
32/// use rag_toolchain::stores::*;
33/// use rag_toolchain::common::*;
34///
35/// async fn store(embeddings: Vec<Embedding>) {
36/// let embedding_model: OpenAIEmbeddingModel = OpenAIEmbeddingModel::TextEmbedding3Small;
37/// let table_name: &str = "table_name";
38/// let store: PostgresVectorStore = PostgresVectorStore::try_new(table_name, embedding_model)
39/// .await.unwrap();
40/// store.store_batch(embeddings).await.unwrap();
41/// }
42/// ```
43#[derive(Debug, Clone)]
44pub struct PostgresVectorStore {
45 /// We hold a connection pool to the database
46 pool: Pool<Postgres>,
47 /// The name of the table we are operating on
48 table_name: String,
49}
50
51impl PostgresVectorStore {
52 /// # [`PostgresVectorStore::try_new`]
53 ///
54 /// This constructor is used to create a new PostgresVectorStore. It will read the required
55 /// environment variables in. Try and connect to your postgres database and then create a table
56 /// with the given name and the expected columns. If the table already exists with the same name
57 /// it will not be re-created.
58 ///
59 /// # Arguments
60 /// * `table_name`: &[`str`] - The name of the table to store the embeddings in.
61 /// * `embedding_model`: impl [`EmbeddingModel`] - The embedding model to use to store the embeddings
62 ///
63 /// # Errors
64 /// * [`PostgresVectorError::EnvVarError`] if the required environment variables are not set.
65 /// * [`PostgresVectorError::ConnectionError`] if the connection to the database could not be established.
66 /// * [`PostgresVectorError::TableCreationError`] if the table could not be created.
67 ///
68 /// # Returns
69 /// * [`PostgresVectorStore`] if the connection and table creation is successful
70 pub async fn try_new(
71 table_name: &str,
72 embedding_model: impl EmbeddingModel,
73 ) -> Result<Self, PostgresVectorStoreError> {
74 dotenv().ok();
75 let username: String = env::var("POSTGRES_USER")?;
76 let password: String = env::var("POSTGRES_PASSWORD")?;
77 let host: String = env::var("POSTGRES_HOST")?;
78 let db_name: String = env::var("POSTGRES_DATABASE")?;
79
80 let embedding_diminsions = embedding_model.metadata().dimensions;
81 let connection_string =
82 format!("postgres://{}:{}@{}/{}", username, password, host, db_name);
83
84 // Connect to the database
85 let pool = PostgresVectorStore::connect(&connection_string)
86 .await
87 .map_err(PostgresVectorStoreError::ConnectionError)?;
88
89 // Create the table
90 PostgresVectorStore::create_table(&pool, table_name, embedding_diminsions)
91 .await
92 .map_err(PostgresVectorStoreError::TableCreationError)?;
93
94 Ok(PostgresVectorStore {
95 pool,
96 table_name: table_name.into(),
97 })
98 }
99
100 /// # [`PostgresVectorStore::try_new_with_pool`]
101 ///
102 /// This is an alternative constructor that allows you to pass in a connection pool.
103 /// This was added as it may be the case people want to establish one connection pool
104 /// and then shared it across multiple [`PostgresVectorStore`]s managing different tables.
105 ///
106 /// # Arguments
107 /// * `pool`: [`sqlx::Pool<Postgres>`] - a pre established connection pool.
108 /// * `table_name`: &[`str`] - The name of the table to store the embeddings in.
109 /// * `embedding_model`: impl[`EmbeddingModel`] - The embedding model used for the genrated embeddings.
110 ///
111 /// # Errors
112 /// * [`PostgresVectorError::TableCreationError`] if the table could not be created
113 ///
114 /// # Returns
115 /// * [`PostgresVectorStore`] if the table creation is successful.
116 pub async fn try_new_with_pool(
117 pool: Pool<Postgres>,
118 table_name: &str,
119 embedding_model: impl EmbeddingModel,
120 ) -> Result<Self, PostgresVectorStoreError> {
121 let embedding_diminsions = embedding_model.metadata().dimensions;
122
123 // Create the table
124 PostgresVectorStore::create_table(&pool, table_name, embedding_diminsions)
125 .await
126 .map_err(PostgresVectorStoreError::TableCreationError)?;
127
128 Ok(PostgresVectorStore {
129 pool,
130 table_name: table_name.into(),
131 })
132 }
133
134 /// # [`PostgresVectorStore::get_pool`]
135 ///
136 /// Getter for the internal connection pool.
137 /// This is useful if you want to do any further operations on the database
138 /// such as enabling an index on the table.
139 ///
140 /// # Returns
141 /// * [`Pool`] - The connection pool
142 pub fn get_pool(&self) -> Pool<Postgres> {
143 self.pool.clone()
144 }
145
146 /// # [`PostgresVectorStore::as_retriever`]
147 ///
148 /// This function allows us to convert the store into a retriever.
149 /// Note that the returned retriever is bound to the same table as the store.
150 ///
151 /// # Arguments
152 /// * `embedding_client`: [`AsyncEmbeddingClient`] - The client we use to embed
153 /// ` income text before the similarity search.
154 /// * `distance_function`: [`DistanceFunction`] - The distance function to use to
155 /// compare the embeddings
156 ///
157 /// # Returns
158 /// [`PostgresVectorRetriever`] - The retriever that can be used to search for similar text.
159 pub fn as_retriever<T: AsyncEmbeddingClient>(
160 &self,
161 embedding_client: T,
162 distance_function: DistanceFunction,
163 ) -> PostgresVectorRetriever<T> {
164 PostgresVectorRetriever::new(
165 self.pool.clone(),
166 self.table_name.clone(),
167 embedding_client,
168 distance_function,
169 )
170 }
171
172 /// # [`PostgresVectorStore::connect`]
173 /// Allows us to establish a connection to a database and store the connection pool
174 ///
175 /// # Arguments
176 /// * `connection_string`: &[`str`] - The connection string to use to connect to the database
177 ///
178 /// # Errors
179 /// * [`sqlx::Error`] if the connection could not be established.
180 ///
181 /// # Returns
182 /// * [`Pool`] which can be used to query the database
183 async fn connect(connection_string: &str) -> Result<Pool<Postgres>, sqlx::Error> {
184 let pool: Pool<Postgres> = PgPoolOptions::new()
185 .max_connections(5)
186 .connect(connection_string)
187 .await?;
188 Ok(pool)
189 }
190
191 /// # [`PostgresVectorStore::create_table`]
192 /// We call the create table automatically when the struct is created
193 ///
194 /// # Arguments
195 /// * `pool`: [`sqlx::Pool<Postgres>`] - The connection pool to use to create the table
196 /// * `table_name`: &[`str`] - The name of the table to create
197 /// * `vector_dimension`: [`usize`] - The dimension of the vector to store
198 ///
199 /// # Errors
200 /// * [`sqlx::Error`] if the table could not be created.
201 ///
202 /// # Returns
203 /// * [`PgQueryResult`] which can be used to check if the table was created successfully
204 async fn create_table(
205 pool: &Pool<Postgres>,
206 table_name: &str,
207 vector_dimension: usize,
208 ) -> Result<PgQueryResult, sqlx::Error> {
209 let statement = format!(
210 "CREATE TABLE IF NOT EXISTS {} (
211 id SERIAL PRIMARY KEY,
212 content TEXT NOT NULL,
213 embedding VECTOR({}) NOT NULL,
214 metadata JSONB
215 )",
216 table_name, vector_dimension
217 );
218 sqlx::query(&statement).execute(pool).await
219 }
220
221 /// # [`PostgresVectorStore::insert_row_sql`]
222 /// Helper function to generate the sql query for inserting a new row
223 ///
224 /// # Arguments
225 /// * `table_name`: &[`str`] - The name of the table to insert into
226 ///
227 /// # Returns
228 /// * [`String`] - The sql query
229 fn insert_row_sql(table_name: &str) -> String {
230 format!(
231 "INSERT INTO {} (content, embedding, metadata) VALUES ($1, $2, $3)",
232 table_name
233 )
234 }
235
236 /// # [`PostgresVectorStore::bind_to_query`]
237 /// Helper function to bind an [`Embedding`] to an [`sqlx::query::Query`]
238 /// the retuned query can then have [`sqlx::query::Query::execute`] called on it to
239 /// insert the row.
240 fn bind_to_query(
241 query: &str,
242 embedding: Embedding,
243 ) -> sqlx::query::Query<'_, Postgres, PgArguments> {
244 let chunk: &Chunk = embedding.chunk();
245 let text: String = chunk.content().to_string();
246 let metadata = chunk.metadata().clone();
247 let vector: Vec<f32> = embedding.vector();
248 sqlx::query(query).bind(text).bind(vector).bind(metadata)
249 }
250}
251
252impl EmbeddingStore for PostgresVectorStore {
253 type ErrorType = PostgresVectorStoreError;
254 /// # [`PostgresVectorStore::store`]
255 /// This is done as a single insert statement.
256 ///
257 /// # Arguments
258 /// * `embedding`: [`Embedding`] - to insert
259 ///
260 /// # Errors
261 /// * [`PostgresVectorError::InsertError`] if the insert fails
262 ///
263 /// # Returns
264 /// * [`()`] if the insert succeeds
265 async fn store(&self, embedding: Embedding) -> Result<(), PostgresVectorStoreError> {
266 let query: String = PostgresVectorStore::insert_row_sql(&self.table_name);
267 Self::bind_to_query(&query, embedding)
268 .execute(&self.pool)
269 .await
270 .map_err(PostgresVectorStoreError::InsertError)?;
271 Ok(())
272 }
273
274 /// # [`PostgresVectorStore::store_batch`]
275 /// This is done as a single transaction with multiple insert statements.
276 ///
277 /// # Arguments
278 /// * `embeddings`: [`Vec<Embedding>`] - A vector of embeddings to insert
279 ///
280 /// # Errors
281 /// * [`PostgresVectorError::TransactionError`] if the transaction fails
282 ///
283 /// # Returns
284 /// * [`()`] if the transaction succeeds
285 async fn store_batch(
286 &self,
287 embeddings: Vec<Embedding>,
288 ) -> Result<(), PostgresVectorStoreError> {
289 let query: String = PostgresVectorStore::insert_row_sql(&self.table_name);
290 let mut transaction = self
291 .pool
292 .begin()
293 .await
294 .map_err(PostgresVectorStoreError::TransactionError)?;
295
296 for embedding in embeddings {
297 Self::bind_to_query(&query, embedding)
298 .execute(&mut *transaction)
299 .await
300 .map_err(PostgresVectorStoreError::InsertError)?;
301 }
302
303 transaction
304 .commit()
305 .await
306 .map_err(PostgresVectorStoreError::TransactionError)?;
307 Ok(())
308 }
309}
310
311/// # [`PostgresVectorError`]
312/// This Error enum wraps all the errors that can occur when using
313/// the PgVector struct with contextual meaning.
314#[derive(Error, Debug)]
315pub enum PostgresVectorStoreError {
316 /// Error when an environment variable is not set
317 #[error("Environment Variable Error: {0}")]
318 EnvVarError(VarError),
319 /// Error when the connection to the database could not be established
320 #[error("Connection Error: {0}")]
321 ConnectionError(sqlx::Error),
322 /// Error when the table could not be created
323 #[error("Table Creation Error: {0}")]
324 TableCreationError(sqlx::Error),
325 /// Error when calling [`PostgresVectorStore::store()`] fails
326 #[error("Upsert Error: {0}")]
327 InsertError(sqlx::Error),
328 /// Error when calling [`PostgresVectorStore::store_batch()`] fails
329 #[error("Transaction Error: {0}")]
330 TransactionError(sqlx::Error),
331}
332
333impl From<VarError> for PostgresVectorStoreError {
334 fn from(error: VarError) -> Self {
335 PostgresVectorStoreError::EnvVarError(error)
336 }
337}
338
#[cfg(all(test, feature = "pg_vector"))]
mod tests {
    use super::*;
    use crate::common::OpenAIEmbeddingModel::TextEmbeddingAda002;

    /// Sets the required variables one at a time. While any of them is still missing the
    /// constructor must fail with `EnvVarError`; once all four are set it proceeds to the
    /// connection attempt, which this test expects to fail with `ConnectionError`.
    #[tokio::test]
    async fn test_throws_correct_errors() {
        let required_vars = [
            ("POSTGRES_USER", "postgres"),
            ("POSTGRES_PASSWORD", "postgres"),
            ("POSTGRES_HOST", "localhost"),
            ("POSTGRES_DATABASE", "postgres"),
        ];

        for &(name, value) in &required_vars {
            let error = PostgresVectorStore::try_new("test", TextEmbeddingAda002)
                .await
                .unwrap_err();
            assert!(matches!(error, PostgresVectorStoreError::EnvVarError(_)));
            std::env::set_var(name, value);
        }

        let error = PostgresVectorStore::try_new("test", TextEmbeddingAda002)
            .await
            .unwrap_err();
        assert!(matches!(
            error,
            PostgresVectorStoreError::ConnectionError(_)
        ));
    }
}