rag_toolchain/retrievers/postgres_vector_retriever.rs
1use crate::clients::AsyncEmbeddingClient;
2use crate::common::{Chunk, Chunks, Embedding};
3use crate::retrievers::traits::AsyncRetriever;
4use pgvector::Vector;
5use sqlx::{Pool, Postgres};
6use std::error::Error;
7use std::num::NonZeroU32;
8use thiserror::Error;
9
10/// # [`PostgresVectorRetriever`]
11///
12/// This struct is a allows for the retrieval of similar text from a postgres database.
13/// It is parameterized over a type T which implements the AsyncEmbeddingClient trait.
14/// This is because text needs to be embeded before it can be compared to other text.
15/// You must connect first create a PostgresVectorStore as this handles connecting to the database.
16/// then you can calle .as_retriever() to convert it to retriever.
17///
18/// # Examples
19/// ```
20/// use rag_toolchain::retrievers::*;
21/// use rag_toolchain::clients::*;
22/// use rag_toolchain::common::*;
23/// use rag_toolchain::stores::*;
24/// use std::num::NonZeroU32;
25///
26/// async fn retrieve() {
27/// let chunk: Chunk = Chunk::new("This is the text you want to retrieve something similar to");
28/// let top_k: NonZeroU32 = NonZeroU32::new(5).unwrap();
29/// let distance_function: DistanceFunction = DistanceFunction::Cosine;
30/// let embedding_model: OpenAIEmbeddingModel = OpenAIEmbeddingModel::TextEmbedding3Small;
31/// let client: OpenAIEmbeddingClient = OpenAIEmbeddingClient::try_new(embedding_model).unwrap();
32/// let store: PostgresVectorStore = PostgresVectorStore::try_new("table_name", embedding_model).await.unwrap();
33/// let retriever: PostgresVectorRetriever<OpenAIEmbeddingClient> = store.as_retriever(client, distance_function);
34/// // This will return the top 5 most similar chunks to the input text.
35/// let similar_text: Chunks = retriever.retrieve(chunk.content(), top_k).await.unwrap();
36/// }
37/// ```
38pub struct PostgresVectorRetriever<T>
39where
40 T: AsyncEmbeddingClient,
41{
42 pool: Pool<Postgres>,
43 table_name: String,
44 embedding_client: T,
45 distance_function: DistanceFunction,
46}
47
48impl<T: AsyncEmbeddingClient> PostgresVectorRetriever<T> {
49 /// # [`PostgresVectorRetriever::new`]
50 /// This constructor is only used internally to allow .as_retriever methods to create a retriever.
51 ///
52 /// # Arguments
53 /// * `pool`: [`sqlx::Pool<Postgres>`] - Which we can use to interact with the database.
54 /// * `table_name`: [`String`] - The name of the table which contains the vectors.
55 /// * `embedding_client`: [`T`] - An instance of a type which implements the AsyncEmbeddingClient trait.
56 ///
57 /// # Returns
58 /// * [`PostgresVectorRetriever`] the created struct
59 pub(crate) fn new(
60 pool: Pool<Postgres>,
61 table_name: String,
62 embedding_client: T,
63 distance_function: DistanceFunction,
64 ) -> Self {
65 PostgresVectorRetriever {
66 pool,
67 table_name,
68 embedding_client,
69 distance_function,
70 }
71 }
72
73 /// # [`PostgresVectorRetriever::select_row_sql`]
74 ///
75 /// Helper function to genrate the sql query for a similarity search.
76 ///
77 /// # Arguments
78 /// * `table_name`: &[`str`] - The name of the table to search.
79 /// * `distance_function`: [`DistanceFunction`] - The distance function to use.
80 ///
81 /// # Returns
82 /// * [`String`] - The sql query.
83 fn select_row_sql(table_name: &str, distance_function: DistanceFunction) -> String {
84 format!(
85 "SELECT id, content, embedding, metadata FROM {} ORDER BY embedding {} $1::vector LIMIT $2",
86 table_name,
87 distance_function.to_sql_string()
88 )
89 }
90}
91
92impl<T> AsyncRetriever for PostgresVectorRetriever<T>
93where
94 T: AsyncEmbeddingClient + Sync,
95 T::ErrorType: 'static,
96{
97 // We parameterize over the error type of the embedding client.
98 type ErrorType = PostgresRetrieverError<T::ErrorType>;
99
100 /// # [`PostgresVectorRetriever::retrieve`]
101 ///
102 /// Implementation of the retrieve function for [`PostgresVectorRetriever`].
103 /// This allows us to retrieve similar text from the vector database.
104 ///
105 /// # Arguments
106 /// * `text`: &[`str`] - The text we are searching for similar text against.
107 /// * `top_k`: [`NonZeroU32`] - The number of results to return.
108 ///
109 /// # Errors
110 /// * [`PostgresRetrieverError::EmbeddingClientError`] - If the embedding client returns an error.
111 /// * [`PostgresRetrieverError::QueryError`] - If there is an error querying the database.
112 ///
113 /// # Returns
114 /// * [`Chunks`] which are the most similar to the input text.
115 async fn retrieve(&self, text: &str, top_k: NonZeroU32) -> Result<Chunks, Self::ErrorType> {
116 let k: i32 = top_k.get() as i32;
117 let chunk: Chunk = Chunk::new(text);
118 let embedding: Embedding = self
119 .embedding_client
120 .generate_embedding(chunk)
121 .await
122 .map_err(PostgresRetrieverError::EmbeddingClientError)?;
123
124 let query: String = Self::select_row_sql(&self.table_name, self.distance_function.clone());
125 let vector: Vec<f32> = embedding.vector();
126
127 let similar_text: Vec<PostgresRow> = sqlx::query_as::<_, PostgresRow>(&query)
128 .bind(vector)
129 .bind(k)
130 .fetch_all(&self.pool)
131 .await
132 .map_err(PostgresRetrieverError::QueryError)?;
133
134 Ok(similar_text
135 .into_iter()
136 .map(|row| Chunk::new_with_metadata(row.content, row.metadata))
137 .collect())
138 }
139}
140
/// # [`DistanceFunction`]
///
/// Selects the metric used to compare vectors during a similarity search.
/// Each variant corresponds to one pgvector operator; see
/// [`DistanceFunction::to_sql_string`] for the exact mapping.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DistanceFunction {
    /// Euclidean (L2) distance, rendered as the `<->` operator.
    L2,
    /// Cosine distance, rendered as the `<=>` operator.
    Cosine,
    /// Inner product, rendered as the `<#>` operator (pgvector returns the
    /// negative inner product so that ascending order ranks best matches first).
    InnerProduct,
}
150
151/// # [`PostgresRow`]
152/// Type that represents a row in our defined structure
153/// which allows us to use [`sqlx::query_as`].
154#[derive(Debug, Clone, PartialEq, sqlx::FromRow)]
155pub struct PostgresRow {
156 pub id: i32,
157 pub content: String,
158 pub embedding: Vector,
159 #[sqlx(json)]
160 pub metadata: serde_json::Value,
161}
162
163impl DistanceFunction {
164 pub fn to_sql_string(&self) -> &str {
165 match self {
166 DistanceFunction::L2 => "<->",
167 DistanceFunction::Cosine => "<=>",
168 DistanceFunction::InnerProduct => "<#>",
169 }
170 }
171}
172
173/// # [`PostgresRetrieverError`]
174///
175/// This error is generic as it is parameterized over the error type of the embedding client.
176/// This allows us to avoid dynamic dispatched error types.
177#[derive(Error, Debug)]
178pub enum PostgresRetrieverError<T: Error> {
179 /// If an error occured while trying to embed the text supplied
180 /// as an arguement
181 #[error("Embedding Client Error: {0}")]
182 EmbeddingClientError(T),
183 /// If an error occured while doing the similarity search
184 #[error("Embedding Retrieving Similar Text: {0}")]
185 QueryError(sqlx::Error),
186}