1use std::fmt::Display;
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
8 },
9};
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use serde_json::Value;
12use sqlx::PgPool;
13use uuid::Uuid;
14
/// Vector store that keeps documents and their embeddings in a Postgres table.
///
/// `Model` is the embedding model used to embed the query text of a search
/// before it is compared against the stored embeddings.
pub struct PostgresVectorStore<Model: EmbeddingModel> {
    /// Embedding model used to turn query text into a vector at search time.
    model: Model,
    /// Connection pool every statement of this store is executed against.
    pg_pool: PgPool,
    /// Name of the table holding the documents. It is interpolated into the
    /// SQL text as-is, so it has to be a trusted identifier.
    documents_table: String,
    /// Distance operator used to compare stored embeddings with the query vector.
    distance_function: PgVectorDistanceFunction,
}
21
/// Distance operators that can be used to rank stored embeddings against a
/// query vector.
///
/// The [`Display`] implementation renders the SQL operator for each variant,
/// which is what gets spliced into the search query.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// `<->` operator.
    L2,
    /// `<#>` operator (pgvector documents this as the *negative* inner product).
    InnerProduct,
    /// `<=>` operator.
    Cosine,
    /// `<+>` operator.
    L1,
    /// `<~>` operator.
    Hamming,
    /// `<%>` operator.
    Jaccard,
}

impl Display for PgVectorDistanceFunction {
    /// Writes the SQL operator that corresponds to the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a decision here.
        let operator = match self {
            Self::L2 => "<->",
            Self::InnerProduct => "<#>",
            Self::Cosine => "<=>",
            Self::L1 => "<+>",
            Self::Hamming => "<~>",
            Self::Jaccard => "<%>",
        };
        f.write_str(operator)
    }
}
51
/// One row of the full similarity search: the document id, the stored JSON
/// document and its distance to the query vector.
#[derive(Debug, Deserialize, sqlx::FromRow)]
pub struct SearchResult {
    /// Identifier of the stored document.
    id: Uuid,
    /// The document as stored, still as untyped JSON.
    document: Value,
    /// Distance between the stored embedding and the query vector.
    distance: f64,
}
59
/// One row of the id-only similarity search: the document id and its distance
/// to the query vector, without the document payload.
#[derive(Debug, Deserialize, sqlx::FromRow)]
pub struct SearchResultOnlyId {
    /// Identifier of the stored document.
    id: Uuid,
    /// Distance between the stored embedding and the query vector.
    distance: f64,
}
65
66impl SearchResult {
67 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
68 let document: T =
69 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
70 Ok((self.distance, self.id.to_string(), document))
71 }
72}
73
74impl<Model> PostgresVectorStore<Model>
75where
76 Model: EmbeddingModel,
77{
78 pub fn new(
79 model: Model,
80 pg_pool: PgPool,
81 documents_table: Option<String>,
82 distance_function: PgVectorDistanceFunction,
83 ) -> Self {
84 Self {
85 model,
86 pg_pool,
87 documents_table: documents_table.unwrap_or(String::from("documents")),
88 distance_function,
89 }
90 }
91
92 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
93 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
94 }
95
96 fn search_query_full(&self, threshold: Option<f64>) -> String {
97 self.search_query(true, threshold)
98 }
99 fn search_query_only_ids(&self, threshold: Option<f64>) -> String {
100 self.search_query(false, threshold)
101 }
102
103 fn search_query(&self, with_document: bool, threshold: Option<f64>) -> String {
104 let document = if with_document { ", document" } else { "" };
105 format!(
106 "
107 SELECT id{}, distance FROM ( \
108 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
109 FROM {} \
110 {where_clause}
111 ORDER BY id, distance \
112 ) as d \
113 ORDER BY distance \
114 LIMIT $2",
115 document,
116 document,
117 self.distance_function,
118 self.documents_table,
119 where_clause = if let Some(threshold) = threshold {
120 format!("where distance > {threshold}")
121 } else {
122 String::new()
123 }
124 )
125 }
126}
127
128impl<Model> InsertDocuments for PostgresVectorStore<Model>
129where
130 Model: EmbeddingModel + Send + Sync,
131{
132 async fn insert_documents<Doc: Serialize + Embed + Send>(
133 &self,
134 documents: Vec<(Doc, OneOrMany<Embedding>)>,
135 ) -> Result<(), VectorStoreError> {
136 for (document, embeddings) in documents {
137 let id = Uuid::new_v4();
138 let json_document = serde_json::to_value(&document).unwrap();
139
140 for embedding in embeddings {
141 let embedding_text = embedding.document;
142 let embedding: Vec<f64> = embedding.vec;
143
144 sqlx::query(
145 format!(
146 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
147 self.documents_table
148 )
149 .as_str(),
150 )
151 .bind(id)
152 .bind(&json_document)
153 .bind(&embedding_text)
154 .bind(&embedding)
155 .execute(&self.pg_pool)
156 .await
157 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
158 }
159 }
160
161 Ok(())
162 }
163}
164
165impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
166where
167 Model: EmbeddingModel,
168{
169 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
172 &self,
173 req: VectorSearchRequest,
174 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
175 if req.samples() > i64::MAX as u64 {
176 return Err(VectorStoreError::DatastoreError(
177 format!(
178 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
179 i64::MAX
180 )
181 .into(),
182 ));
183 }
184
185 let embedded_query: pgvector::Vector = self
186 .model
187 .embed_text(req.query())
188 .await?
189 .vec
190 .iter()
191 .map(|&x| x as f32)
192 .collect::<Vec<f32>>()
193 .into();
194
195 let rows: Vec<SearchResult> =
196 sqlx::query_as(self.search_query_full(req.threshold()).as_str())
197 .bind(embedded_query)
198 .bind(req.samples() as i64)
199 .fetch_all(&self.pg_pool)
200 .await
201 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
202
203 let rows: Vec<(f64, String, T)> = rows
204 .into_iter()
205 .flat_map(SearchResult::into_result)
206 .collect();
207
208 Ok(rows)
209 }
210
211 async fn top_n_ids(
213 &self,
214 req: VectorSearchRequest,
215 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
216 if req.samples() > i64::MAX as u64 {
217 return Err(VectorStoreError::DatastoreError(
218 format!(
219 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
220 i64::MAX
221 )
222 .into(),
223 ));
224 }
225 let embedded_query: pgvector::Vector = self
226 .model
227 .embed_text(req.query())
228 .await?
229 .vec
230 .iter()
231 .map(|&x| x as f32)
232 .collect::<Vec<f32>>()
233 .into();
234
235 let rows: Vec<SearchResultOnlyId> =
236 sqlx::query_as(self.search_query_only_ids(req.threshold()).as_str())
237 .bind(embedded_query)
238 .bind(req.samples() as i64)
239 .fetch_all(&self.pg_pool)
240 .await
241 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
242
243 let rows: Vec<(f64, String)> = rows
244 .into_iter()
245 .map(|row| (row.distance, row.id.to_string()))
246 .collect();
247
248 Ok(rows)
249 }
250}