use crate::clients::AsyncEmbeddingClient;
use crate::common::{Chunk, Embedding, EmbeddingModel};
use crate::retrievers::{DistanceFunction, PostgresVectorRetriever};
use crate::stores::traits::EmbeddingStore;
use sqlx::postgres::{PgPoolOptions, PgQueryResult};
use sqlx::{postgres::PgArguments, Pool, Postgres};
use std::env::{self, VarError};
use thiserror::Error;
use dotenv::dotenv;
/// Embedding store backed by a Postgres table with a `VECTOR(n)` column.
///
/// Holds a connection pool and the name of the table that rows are inserted into.
#[derive(Debug, Clone)]
pub struct PostgresVectorStore {
    // Connection pool; a clone of this handle is given to retrievers built via `as_retriever`.
    pool: Pool<Postgres>,
    // Target table. It is interpolated into SQL text without quoting or escaping,
    // so it must be a trusted identifier.
    table_name: String,
}
impl PostgresVectorStore {
    /// Creates a store from `POSTGRES_*` environment variables and ensures its table exists.
    ///
    /// `dotenv()` is called first, so values may also come from a `.env` file. The variables
    /// `POSTGRES_USER`, `POSTGRES_PASSWORD`, `POSTGRES_HOST` and `POSTGRES_DATABASE` are
    /// required.
    ///
    /// The username and password are percent-encoded before being placed in the connection
    /// URL, so they may safely contain characters such as `@`, `:`, `/`, `#` or `%`.
    /// `POSTGRES_HOST` and `POSTGRES_DATABASE` are inserted into the URL verbatim.
    ///
    /// The vector column is sized from `embedding_model.metadata().dimensions`.
    ///
    /// # Errors
    ///
    /// - [`PostgresVectorStoreError::EnvVarError`] if a required variable is missing or is
    ///   not valid Unicode.
    /// - [`PostgresVectorStoreError::ConnectionError`] if the pool cannot connect.
    /// - [`PostgresVectorStoreError::TableCreationError`] if creating the table fails.
    pub async fn try_new(
        table_name: &str,
        embedding_model: impl EmbeddingModel,
    ) -> Result<Self, PostgresVectorStoreError> {
        dotenv().ok();
        let username = env::var("POSTGRES_USER")?;
        let password = env::var("POSTGRES_PASSWORD")?;
        let host = env::var("POSTGRES_HOST")?;
        let db_name = env::var("POSTGRES_DATABASE")?;
        let connection_string =
            PostgresVectorStore::connection_string(&username, &password, &host, &db_name);
        let pool = PostgresVectorStore::connect(&connection_string)
            .await
            .map_err(PostgresVectorStoreError::ConnectionError)?;
        // Table setup is identical to the caller-supplied-pool path, so reuse it.
        PostgresVectorStore::try_new_with_pool(pool, table_name, embedding_model).await
    }

    /// Creates a store on an existing pool, creating the table if it does not exist.
    ///
    /// The vector column is sized from `embedding_model.metadata().dimensions`.
    ///
    /// # Errors
    ///
    /// [`PostgresVectorStoreError::TableCreationError`] if creating the table fails.
    pub async fn try_new_with_pool(
        pool: Pool<Postgres>,
        table_name: &str,
        embedding_model: impl EmbeddingModel,
    ) -> Result<Self, PostgresVectorStoreError> {
        let dimensions = embedding_model.metadata().dimensions;
        PostgresVectorStore::create_table(&pool, table_name, dimensions)
            .await
            .map_err(PostgresVectorStoreError::TableCreationError)?;
        Ok(PostgresVectorStore {
            pool,
            table_name: table_name.into(),
        })
    }

    /// Returns a clone of the pool handle.
    ///
    /// A sqlx `Pool` is a reference-counted handle, so the clone shares the same connections.
    pub fn get_pool(&self) -> Pool<Postgres> {
        self.pool.clone()
    }

    /// Builds a [`PostgresVectorRetriever`] over this store's pool and table.
    ///
    /// `embedding_client` and `distance_function` are passed through to the retriever unchanged.
    pub fn as_retriever<T: AsyncEmbeddingClient>(
        &self,
        embedding_client: T,
        distance_function: DistanceFunction,
    ) -> PostgresVectorRetriever<T> {
        PostgresVectorRetriever::new(
            self.pool.clone(),
            self.table_name.clone(),
            embedding_client,
            distance_function,
        )
    }

    /// Builds a `postgres://user:password@host/db` URL.
    ///
    /// The username and password are percent-encoded so that reserved URL characters cannot
    /// be misread as delimiters. sqlx percent-decodes them again when parsing the URL.
    fn connection_string(username: &str, password: &str, host: &str, db_name: &str) -> String {
        format!(
            "postgres://{}:{}@{}/{}",
            PostgresVectorStore::percent_encode_userinfo(username),
            PostgresVectorStore::percent_encode_userinfo(password),
            host,
            db_name
        )
    }

    /// Percent-encodes every byte outside the RFC 3986 "unreserved" set (`A-Z a-z 0-9 - . _ ~`).
    fn percent_encode_userinfo(raw: &str) -> String {
        let mut encoded = String::with_capacity(raw.len());
        for byte in raw.bytes() {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                encoded.push(char::from(byte));
            } else {
                encoded.push_str(&format!("%{byte:02X}"));
            }
        }
        encoded
    }

    /// Opens a connection pool limited to at most 5 connections.
    async fn connect(connection_string: &str) -> Result<Pool<Postgres>, sqlx::Error> {
        let pool: Pool<Postgres> = PgPoolOptions::new()
            .max_connections(5)
            .connect(connection_string)
            .await?;
        Ok(pool)
    }

    /// Issues `CREATE TABLE IF NOT EXISTS` for the embedding table.
    ///
    /// The columns are `id`, `content`, `embedding VECTOR(vector_dimension)` and `metadata JSONB`.
    ///
    /// `table_name` is spliced into the SQL text without quoting. It must therefore be a trusted
    /// identifier, never user input.
    ///
    /// NOTE(review): `VECTOR` is the pgvector column type. This function does not run
    /// `CREATE EXTENSION`, so the extension presumably must already be enabled.
    async fn create_table(
        pool: &Pool<Postgres>,
        table_name: &str,
        vector_dimension: usize,
    ) -> Result<PgQueryResult, sqlx::Error> {
        let statement = format!(
            "CREATE TABLE IF NOT EXISTS {} (
id SERIAL PRIMARY KEY,
content TEXT NOT NULL,
embedding VECTOR({}) NOT NULL,
metadata JSONB
)",
            table_name, vector_dimension
        );
        sqlx::query(&statement).execute(pool).await
    }

    /// Returns the parameterised `INSERT` used for one row.
    ///
    /// The placeholders are `$1` = content, `$2` = embedding and `$3` = metadata.
    fn insert_row_sql(table_name: &str) -> String {
        format!(
            "INSERT INTO {} (content, embedding, metadata) VALUES ($1, $2, $3)",
            table_name
        )
    }

    /// Binds an embedding's chunk text, vector and metadata to `query`.
    ///
    /// The bind order matches the placeholder order produced by `insert_row_sql`.
    fn bind_to_query(
        query: &str,
        embedding: Embedding,
    ) -> sqlx::query::Query<'_, Postgres, PgArguments> {
        let chunk: &Chunk = embedding.chunk();
        let text: String = chunk.content().to_string();
        let metadata = chunk.metadata().clone();
        let vector: Vec<f32> = embedding.vector();
        sqlx::query(query).bind(text).bind(vector).bind(metadata)
    }
}
impl EmbeddingStore for PostgresVectorStore {
    type ErrorType = PostgresVectorStoreError;

    /// Inserts a single embedding row directly on the pool, with no explicit transaction.
    async fn store(&self, embedding: Embedding) -> Result<(), PostgresVectorStoreError> {
        let insert_sql = Self::insert_row_sql(&self.table_name);
        let outcome = Self::bind_to_query(&insert_sql, embedding)
            .execute(&self.pool)
            .await;
        match outcome {
            Ok(_) => Ok(()),
            Err(e) => Err(PostgresVectorStoreError::InsertError(e)),
        }
    }

    /// Inserts every embedding inside a single transaction, one statement per row.
    ///
    /// On the first failed insert the function returns early without committing. The
    /// transaction is then dropped, which sqlx documents as rolling it back.
    async fn store_batch(
        &self,
        embeddings: Vec<Embedding>,
    ) -> Result<(), PostgresVectorStoreError> {
        let insert_sql = Self::insert_row_sql(&self.table_name);
        let mut tx = match self.pool.begin().await {
            Ok(tx) => tx,
            Err(e) => return Err(PostgresVectorStoreError::TransactionError(e)),
        };
        for item in embeddings.into_iter() {
            let inserted = Self::bind_to_query(&insert_sql, item)
                .execute(&mut *tx)
                .await;
            if let Err(e) = inserted {
                return Err(PostgresVectorStoreError::InsertError(e));
            }
        }
        tx.commit()
            .await
            .map_err(PostgresVectorStoreError::TransactionError)
    }
}
#[derive(Error, Debug)]
pub enum PostgresVectorStoreError {
#[error("Environment Variable Error: {0}")]
EnvVarError(VarError),
#[error("Connection Error: {0}")]
ConnectionError(sqlx::Error),
#[error("Table Creation Error: {0}")]
TableCreationError(sqlx::Error),
#[error("Upsert Error: {0}")]
InsertError(sqlx::Error),
#[error("Transaction Error: {0}")]
TransactionError(sqlx::Error),
}
impl From<VarError> for PostgresVectorStoreError {
fn from(error: VarError) -> Self {
PostgresVectorStoreError::EnvVarError(error)
}
}
#[cfg(all(test, feature = "pg_vector"))]
mod tests {
    use super::*;
    use crate::common::OpenAIEmbeddingModel::TextEmbeddingAda002;

    /// Variables read by `PostgresVectorStore::try_new`, in the order it reads them.
    /// Each is paired with the value the test sets once its absence has been checked.
    const REQUIRED_VARS: [(&str, &str); 4] = [
        ("POSTGRES_USER", "postgres"),
        ("POSTGRES_PASSWORD", "postgres"),
        ("POSTGRES_HOST", "localhost"),
        ("POSTGRES_DATABASE", "postgres"),
    ];

    #[tokio::test]
    async fn test_throws_correct_errors() {
        // Start from a clean slate so values exported by the surrounding shell/CI cannot
        // make the EnvVarError assertions below fail.
        // NOTE(review): `try_new` also calls `dotenv()`, so a `.env` file that defines
        // these variables will still interfere with this test.
        for (name, _) in REQUIRED_VARS {
            std::env::remove_var(name);
        }

        // Each variable in turn is the first one missing, so every call must fail on it.
        for (name, value) in REQUIRED_VARS {
            let result = PostgresVectorStore::try_new("test", TextEmbeddingAda002)
                .await
                .unwrap_err();
            assert!(
                matches!(result, PostgresVectorStoreError::EnvVarError(_)),
                "expected EnvVarError while {name} is unset, got {result:?}"
            );
            std::env::set_var(name, value);
        }

        // With every variable present, the failure moves on to connecting.
        // NOTE(review): this assumes no reachable Postgres on localhost accepts these credentials.
        let result = PostgresVectorStore::try_new("test", TextEmbeddingAda002)
            .await
            .unwrap_err();
        assert!(
            matches!(result, PostgresVectorStoreError::ConnectionError(_)),
            "expected ConnectionError, got {result:?}"
        );
    }
}