use crate::clients::AsyncEmbeddingClient;
use crate::common::{Chunk, Chunks, Embedding};
use crate::retrievers::traits::AsyncRetriever;
use pgvector::Vector;
use sqlx::{Pool, Postgres};
use std::error::Error;
use std::num::NonZeroU32;
use thiserror::Error;
/// Retriever that looks up stored chunks in a Postgres table by vector distance.
///
/// `T` is the embedding client used to turn query text into an embedding
/// before the table is searched.
pub struct PostgresVectorRetriever<T: AsyncEmbeddingClient> {
    /// Connection pool every query is executed against.
    pool: Pool<Postgres>,
    /// Name of the table that is searched; interpolated into the generated SQL.
    table_name: String,
    /// Client that produces the embedding for the query text.
    embedding_client: T,
    /// Distance operator used to order rows relative to the query embedding.
    distance_function: DistanceFunction,
}
impl<T> PostgresVectorRetriever<T>
where
    T: AsyncEmbeddingClient,
{
    /// Assembles a retriever from its already-constructed parts.
    ///
    /// * `pool` - connection pool queries are run against.
    /// * `table_name` - table to search.
    /// * `embedding_client` - client used to embed query text.
    /// * `distance_function` - operator used to rank rows against the query embedding.
    pub(crate) fn new(
        pool: Pool<Postgres>,
        table_name: String,
        embedding_client: T,
        distance_function: DistanceFunction,
    ) -> Self {
        Self {
            pool,
            table_name,
            embedding_client,
            distance_function,
        }
    }

    /// Renders the similarity query for `table_name`.
    ///
    /// The statement selects `id`, `content`, `embedding` and `metadata`, orders
    /// rows by `embedding <operator> $1::vector` and keeps the first `$2` rows.
    /// `$1` and `$2` are left as bind parameters for the caller to fill in.
    ///
    /// NOTE(review): `table_name` is spliced into the statement verbatim (it is
    /// not quoted or escaped), so it must come from a trusted source.
    fn select_row_sql(table_name: &str, distance_function: DistanceFunction) -> String {
        let operator = distance_function.to_sql_string();
        format!(
            "SELECT id, content, embedding, metadata FROM {} ORDER BY embedding {} $1::vector LIMIT $2",
            table_name, operator
        )
    }
}
impl<T> AsyncRetriever for PostgresVectorRetriever<T>
where
    T: AsyncEmbeddingClient + Sync,
    T::ErrorType: 'static,
{
    type ErrorType = PostgresRetrieverError<T::ErrorType>;

    /// Returns up to `top_k` stored chunks ranked against `text`.
    ///
    /// `text` is first embedded with the configured embedding client; the
    /// resulting vector is then bound to the query built by `select_row_sql`,
    /// which orders rows by the configured distance operator and limits the
    /// result to `top_k` rows. Each returned row is converted into a chunk from
    /// its `content` and `metadata` columns.
    ///
    /// # Errors
    ///
    /// * [`PostgresRetrieverError::EmbeddingClientError`] if the embedding
    ///   client fails.
    /// * [`PostgresRetrieverError::QueryError`] if executing the query or
    ///   decoding its rows fails.
    async fn retrieve(&self, text: &str, top_k: NonZeroU32) -> Result<Chunks, Self::ErrorType> {
        // Widen losslessly instead of `as i32`: values above `i32::MAX` used to
        // wrap to a negative number, which the server rejects as a LIMIT.
        let limit = i64::from(top_k.get());

        let embedding: Embedding = self
            .embedding_client
            .generate_embedding(Chunk::new(text))
            .await
            .map_err(PostgresRetrieverError::EmbeddingClientError)?;

        let query = Self::select_row_sql(&self.table_name, self.distance_function.clone());
        let vector: Vec<f32> = embedding.vector();

        let rows: Vec<PostgresRow> = sqlx::query_as::<_, PostgresRow>(&query)
            .bind(vector)
            .bind(limit)
            .fetch_all(&self.pool)
            .await
            .map_err(PostgresRetrieverError::QueryError)?;

        Ok(rows
            .into_iter()
            .map(|row| Chunk::new_with_metadata(row.content, row.metadata))
            .collect())
    }
}
/// Distance operator used to rank stored embeddings against a query embedding.
///
/// Each variant is rendered as a SQL operator by `to_sql_string`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DistanceFunction {
    /// Rendered as the `<->` operator.
    L2,
    /// Rendered as the `<=>` operator.
    Cosine,
    /// Rendered as the `<#>` operator.
    InnerProduct,
}
/// One row of the similarity query's result set.
///
/// This is the row type passed to `sqlx::query_as` in `retrieve`; its fields
/// carry the same names as the columns that query selects (`id`, `content`,
/// `embedding`, `metadata`).
#[derive(Debug, Clone, PartialEq, sqlx::FromRow)]
pub struct PostgresRow {
/// Value of the `id` column.
pub id: i32,
/// Value of the `content` column; handed to `Chunk::new_with_metadata` by `retrieve`.
pub content: String,
/// Value of the `embedding` column, as a pgvector `Vector`.
pub embedding: Vector,
/// Value of the `metadata` column; handed to `Chunk::new_with_metadata` by `retrieve`.
// NOTE(review): `#[sqlx(json)]` presumably makes the derive decode this column
// as JSON — confirm against the sqlx `FromRow` documentation.
#[sqlx(json)]
pub metadata: serde_json::Value,
}
impl DistanceFunction {
pub fn to_sql_string(&self) -> &str {
match self {
DistanceFunction::L2 => "<->",
DistanceFunction::Cosine => "<=>",
DistanceFunction::InnerProduct => "<#>",
}
}
}
#[derive(Error, Debug)]
pub enum PostgresRetrieverError<T: Error> {
#[error("Embedding Client Error: {0}")]
EmbeddingClientError(T),
#[error("Embedding Retrieving Similar Text: {0}")]
QueryError(sqlx::Error),
}