1use std::fmt::Display;
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
8 },
9};
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use serde_json::Value;
12use sqlx::PgPool;
13use uuid::Uuid;
14
15pub struct PostgresVectorStore<Model: EmbeddingModel> {
16 model: Model,
17 pg_pool: PgPool,
18 documents_table: String,
19 distance_function: PgVectorDistanceFunction,
20}
21
/// Distance function used to rank search results.
///
/// Each variant is rendered (via `Display`) as the pgvector SQL operator placed
/// between the stored `embedding` column and the query vector.
///
/// NOTE(review): the search queries bind the query as a `pgvector::Vector`; confirm the
/// `Hamming` / `Jaccard` operators are supported for the table's embedding column type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// Rendered as `<->`.
    L2,
    /// Rendered as `<#>`.
    InnerProduct,
    /// Rendered as `<=>`; the default used by `PostgresVectorStore::with_defaults`.
    Cosine,
    /// Rendered as `<+>`.
    L1,
    /// Rendered as `<~>`.
    Hamming,
    /// Rendered as `<%>`.
    Jaccard,
}
38
39impl Display for PgVectorDistanceFunction {
40 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
41 match self {
42 PgVectorDistanceFunction::L2 => write!(f, "<->"),
43 PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
44 PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
45 PgVectorDistanceFunction::L1 => write!(f, "<+>"),
46 PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
47 PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
48 }
49 }
50}
51
/// A row returned by the full search query: the stored document plus its
/// distance to the query embedding.
#[derive(Debug, Deserialize, sqlx::FromRow)]
pub struct SearchResult {
    /// Document id; every embedding inserted for the same document shares it.
    id: Uuid,
    /// The stored document as raw JSON; deserialized into the caller's type by `into_result`.
    document: Value,
    /// Value computed by the configured distance operator against the query vector.
    distance: f64,
}
59
/// A row returned by the id-only search query (no document payload is selected).
#[derive(Debug, Deserialize, sqlx::FromRow)]
pub struct SearchResultOnlyId {
    /// Document id; every embedding inserted for the same document shares it.
    id: Uuid,
    /// Value computed by the configured distance operator against the query vector.
    distance: f64,
}
65
66impl SearchResult {
67 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
68 let document: T =
69 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
70 Ok((self.distance, self.id.to_string(), document))
71 }
72}
73
74impl<Model: EmbeddingModel> PostgresVectorStore<Model> {
75 pub fn new(
76 model: Model,
77 pg_pool: PgPool,
78 documents_table: Option<String>,
79 distance_function: PgVectorDistanceFunction,
80 ) -> Self {
81 Self {
82 model,
83 pg_pool,
84 documents_table: documents_table.unwrap_or(String::from("documents")),
85 distance_function,
86 }
87 }
88
89 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
90 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
91 }
92
93 fn search_query_full(&self) -> String {
94 self.search_query(true)
95 }
96 fn search_query_only_ids(&self) -> String {
97 self.search_query(false)
98 }
99
100 fn search_query(&self, with_document: bool) -> String {
101 let document = if with_document { ", document" } else { "" };
102 format!(
103 "
104 SELECT id{}, distance FROM ( \
105 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
106 FROM {} \
107 ORDER BY id, distance \
108 ) as d \
109 ORDER BY distance \
110 LIMIT $2",
111 document, document, self.distance_function, self.documents_table
112 )
113 }
114}
115
116impl<Model> InsertDocuments for PostgresVectorStore<Model>
117where
118 Model: EmbeddingModel + Send + Sync,
119{
120 async fn insert_documents<Doc: Serialize + Embed + Send>(
121 &self,
122 documents: Vec<(Doc, OneOrMany<Embedding>)>,
123 ) -> Result<(), VectorStoreError> {
124 for (document, embeddings) in documents {
125 let id = Uuid::new_v4();
126 let json_document = serde_json::to_value(&document).unwrap();
127
128 for embedding in embeddings {
129 let embedding_text = embedding.document;
130 let embedding: Vec<f64> = embedding.vec;
131
132 sqlx::query(
133 format!(
134 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
135 self.documents_table
136 )
137 .as_str(),
138 )
139 .bind(id)
140 .bind(&json_document)
141 .bind(&embedding_text)
142 .bind(&embedding)
143 .execute(&self.pg_pool)
144 .await
145 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
146 }
147 }
148
149 Ok(())
150 }
151}
152
153impl<Model: EmbeddingModel> VectorStoreIndex for PostgresVectorStore<Model> {
154 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
157 &self,
158 req: VectorSearchRequest,
159 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
160 if req.samples() > i64::MAX as u64 {
161 return Err(VectorStoreError::DatastoreError(
162 format!(
163 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
164 i64::MAX
165 )
166 .into(),
167 ));
168 }
169
170 let embedded_query: pgvector::Vector = self
171 .model
172 .embed_text(req.query())
173 .await?
174 .vec
175 .iter()
176 .map(|&x| x as f32)
177 .collect::<Vec<f32>>()
178 .into();
179
180 let rows: Vec<SearchResult> = sqlx::query_as(self.search_query_full().as_str())
181 .bind(embedded_query)
182 .bind(req.samples() as i64)
183 .fetch_all(&self.pg_pool)
184 .await
185 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
186
187 let rows: Vec<(f64, String, T)> = rows
188 .into_iter()
189 .flat_map(SearchResult::into_result)
190 .collect();
191
192 Ok(rows)
193 }
194
195 async fn top_n_ids(
197 &self,
198 req: VectorSearchRequest,
199 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
200 if req.samples() > i64::MAX as u64 {
201 return Err(VectorStoreError::DatastoreError(
202 format!(
203 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
204 i64::MAX
205 )
206 .into(),
207 ));
208 }
209 let embedded_query: pgvector::Vector = self
210 .model
211 .embed_text(req.query())
212 .await?
213 .vec
214 .iter()
215 .map(|&x| x as f32)
216 .collect::<Vec<f32>>()
217 .into();
218
219 let rows: Vec<SearchResultOnlyId> = sqlx::query_as(self.search_query_only_ids().as_str())
220 .bind(embedded_query)
221 .bind(req.samples() as i64)
222 .fetch_all(&self.pg_pool)
223 .await
224 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
225
226 let rows: Vec<(f64, String)> = rows
227 .into_iter()
228 .map(|row| (row.distance, row.id.to_string()))
229 .collect();
230
231 Ok(rows)
232 }
233}