rag_toolchain/retrievers/
postgres_vector_retriever.rs

1use crate::clients::AsyncEmbeddingClient;
2use crate::common::{Chunk, Chunks, Embedding};
3use crate::retrievers::traits::AsyncRetriever;
4use pgvector::Vector;
5use sqlx::{Pool, Postgres};
6use std::error::Error;
7use std::num::NonZeroU32;
8use thiserror::Error;
9
10/// # [`PostgresVectorRetriever`]
11///
12/// This struct is a allows for the retrieval of similar text from a postgres database.
13/// It is parameterized over a type T which implements the AsyncEmbeddingClient trait.
14/// This is because text needs to be embeded before it can be compared to other text.
15/// You must connect first create a PostgresVectorStore as this handles connecting to the database.
16/// then you can calle .as_retriever() to convert it to retriever.
17///
18/// # Examples
19/// ```
20/// use rag_toolchain::retrievers::*;
21/// use rag_toolchain::clients::*;
22/// use rag_toolchain::common::*;
23/// use rag_toolchain::stores::*;
24/// use std::num::NonZeroU32;
25///
26/// async fn retrieve() {
27///     let chunk: Chunk = Chunk::new("This is the text you want to retrieve something similar to");
28///     let top_k: NonZeroU32 = NonZeroU32::new(5).unwrap();
29///     let distance_function: DistanceFunction = DistanceFunction::Cosine;
30///     let embedding_model: OpenAIEmbeddingModel = OpenAIEmbeddingModel::TextEmbedding3Small;
31///     let client: OpenAIEmbeddingClient = OpenAIEmbeddingClient::try_new(embedding_model).unwrap();
32///     let store: PostgresVectorStore = PostgresVectorStore::try_new("table_name", embedding_model).await.unwrap();
33///     let retriever: PostgresVectorRetriever<OpenAIEmbeddingClient> = store.as_retriever(client, distance_function);
34///     // This will return the top 5 most similar chunks to the input text.
35///     let similar_text: Chunks = retriever.retrieve(chunk.content(), top_k).await.unwrap();
36/// }
37/// ```
38pub struct PostgresVectorRetriever<T>
39where
40    T: AsyncEmbeddingClient,
41{
42    pool: Pool<Postgres>,
43    table_name: String,
44    embedding_client: T,
45    distance_function: DistanceFunction,
46}
47
48impl<T: AsyncEmbeddingClient> PostgresVectorRetriever<T> {
49    /// # [`PostgresVectorRetriever::new`]
50    /// This constructor is only used internally to allow .as_retriever methods to create a retriever.
51    ///
52    /// # Arguments
53    /// * `pool`: [`sqlx::Pool<Postgres>`] - Which we can use to interact with the database.
54    /// * `table_name`: [`String`] - The name of the table which contains the vectors.
55    /// * `embedding_client`: [`T`] - An instance of a type which implements the AsyncEmbeddingClient trait.
56    ///
57    /// # Returns
58    /// * [`PostgresVectorRetriever`] the created struct
59    pub(crate) fn new(
60        pool: Pool<Postgres>,
61        table_name: String,
62        embedding_client: T,
63        distance_function: DistanceFunction,
64    ) -> Self {
65        PostgresVectorRetriever {
66            pool,
67            table_name,
68            embedding_client,
69            distance_function,
70        }
71    }
72
73    /// # [`PostgresVectorRetriever::select_row_sql`]
74    ///
75    /// Helper function to genrate the sql query for a similarity search.
76    ///
77    /// # Arguments
78    /// * `table_name`: &[`str`] - The name of the table to search.
79    /// * `distance_function`: [`DistanceFunction`] - The distance function to use.
80    ///
81    /// # Returns
82    /// * [`String`] - The sql query.
83    fn select_row_sql(table_name: &str, distance_function: DistanceFunction) -> String {
84        format!(
85            "SELECT id, content, embedding, metadata FROM {} ORDER BY embedding {} $1::vector LIMIT $2",
86            table_name,
87            distance_function.to_sql_string()
88        )
89    }
90}
91
92impl<T> AsyncRetriever for PostgresVectorRetriever<T>
93where
94    T: AsyncEmbeddingClient + Sync,
95    T::ErrorType: 'static,
96{
97    // We parameterize over the error type of the embedding client.
98    type ErrorType = PostgresRetrieverError<T::ErrorType>;
99
100    /// # [`PostgresVectorRetriever::retrieve`]
101    ///
102    /// Implementation of the retrieve function for [`PostgresVectorRetriever`].
103    /// This allows us to retrieve similar text from the vector database.
104    ///
105    /// # Arguments
106    /// * `text`: &[`str`] - The text we are searching for similar text against.
107    /// * `top_k`: [`NonZeroU32`] - The number of results to return.
108    ///
109    /// # Errors
110    /// * [`PostgresRetrieverError::EmbeddingClientError`] - If the embedding client returns an error.
111    /// * [`PostgresRetrieverError::QueryError`] - If there is an error querying the database.
112    ///
113    /// # Returns
114    /// * [`Chunks`] which are the most similar to the input text.
115    async fn retrieve(&self, text: &str, top_k: NonZeroU32) -> Result<Chunks, Self::ErrorType> {
116        let k: i32 = top_k.get() as i32;
117        let chunk: Chunk = Chunk::new(text);
118        let embedding: Embedding = self
119            .embedding_client
120            .generate_embedding(chunk)
121            .await
122            .map_err(PostgresRetrieverError::EmbeddingClientError)?;
123
124        let query: String = Self::select_row_sql(&self.table_name, self.distance_function.clone());
125        let vector: Vec<f32> = embedding.vector();
126
127        let similar_text: Vec<PostgresRow> = sqlx::query_as::<_, PostgresRow>(&query)
128            .bind(vector)
129            .bind(k)
130            .fetch_all(&self.pool)
131            .await
132            .map_err(PostgresRetrieverError::QueryError)?;
133
134        Ok(similar_text
135            .into_iter()
136            .map(|row| Chunk::new_with_metadata(row.content, row.metadata))
137            .collect())
138    }
139}
140
/// # [`DistanceFunction`]
/// This is an enum for the types of distance functions
/// that can be used to compare vectors.
///
/// Each variant maps to one SQL operator, see [`DistanceFunction::to_sql_string`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistanceFunction {
    /// Rendered as the `<->` operator in generated SQL.
    L2,
    /// Rendered as the `<=>` operator in generated SQL.
    Cosine,
    /// Rendered as the `<#>` operator in generated SQL.
    InnerProduct,
}
150
/// # [`PostgresRow`]
/// Type that represents a row in our defined structure
/// which allows us to use [`sqlx::query_as`].
///
/// The fields mirror the columns selected by the similarity query:
/// `id`, `content`, `embedding` and `metadata`.
#[derive(Debug, Clone, PartialEq, sqlx::FromRow)]
pub struct PostgresRow {
    /// Value of the `id` column.
    pub id: i32,
    /// Value of the `content` column, i.e. the stored text.
    pub content: String,
    /// Value of the `embedding` column.
    pub embedding: Vector,
    /// Value of the `metadata` column, decoded as JSON.
    #[sqlx(json)]
    pub metadata: serde_json::Value,
}
162
163impl DistanceFunction {
164    pub fn to_sql_string(&self) -> &str {
165        match self {
166            DistanceFunction::L2 => "<->",
167            DistanceFunction::Cosine => "<=>",
168            DistanceFunction::InnerProduct => "<#>",
169        }
170    }
171}
172
173/// # [`PostgresRetrieverError`]
174///
175/// This error is generic as it is parameterized over the error type of the embedding client.
176/// This allows us to avoid dynamic dispatched error types.
177#[derive(Error, Debug)]
178pub enum PostgresRetrieverError<T: Error> {
179    /// If an error occured while trying to embed the text supplied
180    /// as an arguement
181    #[error("Embedding Client Error: {0}")]
182    EmbeddingClientError(T),
183    /// If an error occured while doing the similarity search
184    #[error("Embedding Retrieving Similar Text: {0}")]
185    QueryError(sqlx::Error),
186}