1use std::{fmt::Display, ops::RangeInclusive};
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, VectorStoreError, VectorStoreIndex,
8 request::{SearchFilter, VectorSearchRequest},
9 },
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use serde_json::Value;
13use sqlx::PgPool;
14use uuid::Uuid;
15
16pub struct PostgresVectorStore<Model: EmbeddingModel> {
17 model: Model,
18 pg_pool: PgPool,
19 documents_table: String,
20 distance_function: PgVectorDistanceFunction,
21}
22
/// Distance operator used to rank stored embeddings against the query embedding.
///
/// Each variant renders (via `Display`) as the corresponding pgvector SQL operator;
/// results are ordered by ascending value of that operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgVectorDistanceFunction {
    /// L2 (Euclidean) distance, rendered as `<->`.
    L2,
    /// Inner product, rendered as `<#>`.
    InnerProduct,
    /// Cosine distance, rendered as `<=>`.
    Cosine,
    /// L1 (taxicab) distance, rendered as `<+>`.
    L1,
    /// Hamming distance, rendered as `<~>`.
    Hamming,
    /// Jaccard distance, rendered as `<%>`.
    Jaccard,
}
39
40impl Display for PgVectorDistanceFunction {
41 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
42 match self {
43 PgVectorDistanceFunction::L2 => write!(f, "<->"),
44 PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
45 PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
46 PgVectorDistanceFunction::L1 => write!(f, "<+>"),
47 PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
48 PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
49 }
50 }
51}
52
53#[derive(Clone, Default)]
54pub struct PgSearchFilter {
55 condition: String,
56 values: Vec<serde_json::Value>,
57}
58
59impl SearchFilter for PgSearchFilter {
60 type Value = serde_json::Value;
61
62 fn eq(key: String, value: Self::Value) -> Self {
63 Self {
64 condition: format!("{key} = $"),
65 values: vec![value],
66 }
67 }
68
69 fn gt(key: String, value: Self::Value) -> Self {
70 Self {
71 condition: format!("{key} > $"),
72 values: vec![value],
73 }
74 }
75
76 fn lt(key: String, value: Self::Value) -> Self {
77 Self {
78 condition: format!("{key} < $"),
79 values: vec![value],
80 }
81 }
82
83 fn and(self, rhs: Self) -> Self {
84 Self {
85 condition: format!("({}) AND ({})", self.condition, rhs.condition),
86 values: self.values.into_iter().chain(rhs.values).collect(),
87 }
88 }
89
90 fn or(self, rhs: Self) -> Self {
91 Self {
92 condition: format!("({}) OR ({})", self.condition, rhs.condition),
93 values: self.values.into_iter().chain(rhs.values).collect(),
94 }
95 }
96}
97
98impl PgSearchFilter {
99 fn into_clause(self) -> (String, Vec<serde_json::Value>) {
100 (self.condition, self.values)
101 }
102
103 #[allow(clippy::should_implement_trait)]
104 pub fn not(self) -> Self {
105 Self {
106 condition: format!("NOT ({})", self.condition),
107 values: self.values,
108 }
109 }
110
111 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
112 Self {
113 condition: format!("{key} >= ?"),
114 values: vec![value],
115 }
116 }
117
118 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
119 Self {
120 condition: format!("{key} <= ?"),
121 values: vec![value],
122 }
123 }
124
125 pub fn is_null(key: String) -> Self {
126 Self {
127 condition: format!("{key} is null"),
128 ..Default::default()
129 }
130 }
131
132 pub fn is_not_null(key: String) -> Self {
133 Self {
134 condition: format!("{key} is not null"),
135 ..Default::default()
136 }
137 }
138
139 pub fn between<T>(key: String, range: RangeInclusive<T>) -> Self
140 where
141 T: std::fmt::Display + Into<serde_json::Number> + Copy,
142 {
143 let lo = range.start();
144 let hi = range.end();
145
146 Self {
147 condition: format!("{key} between {lo} and {hi}"),
148 ..Default::default()
149 }
150 }
151
152 pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153 let placeholders = values.iter().map(|_| "?").collect::<Vec<&str>>().join(",");
154
155 Self {
156 condition: format!("{key} is in ({placeholders})"),
157 values,
158 }
159 }
160
161 pub fn like(key: String, pattern: &'static str) -> Self {
166 Self {
167 condition: format!("{key} like {pattern}"),
168 ..Default::default()
169 }
170 }
171
172 pub fn similar_to(key: String, pattern: &'static str) -> Self {
175 Self {
176 condition: format!("{key} similar to {pattern}"),
177 ..Default::default()
178 }
179 }
180}
181
182#[derive(Debug, Deserialize, sqlx::FromRow)]
183pub struct SearchResult {
184 id: Uuid,
185 document: Value,
186 distance: f64,
188}
189
190#[derive(Debug, Deserialize, sqlx::FromRow)]
191pub struct SearchResultOnlyId {
192 id: Uuid,
193 distance: f64,
194}
195
196impl SearchResult {
197 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
198 let document: T =
199 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
200 Ok((self.distance, self.id.to_string(), document))
201 }
202}
203
204impl<Model> PostgresVectorStore<Model>
205where
206 Model: EmbeddingModel,
207{
208 pub fn new(
209 model: Model,
210 pg_pool: PgPool,
211 documents_table: Option<String>,
212 distance_function: PgVectorDistanceFunction,
213 ) -> Self {
214 Self {
215 model,
216 pg_pool,
217 documents_table: documents_table.unwrap_or(String::from("documents")),
218 distance_function,
219 }
220 }
221
222 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
223 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
224 }
225
226 fn search_query_full(
227 &self,
228 req: &VectorSearchRequest<PgSearchFilter>,
229 ) -> (String, Vec<serde_json::Value>) {
230 self.search_query(true, req)
231 }
232
233 fn search_query_only_ids(
234 &self,
235 req: &VectorSearchRequest<PgSearchFilter>,
236 ) -> (String, Vec<serde_json::Value>) {
237 self.search_query(false, req)
238 }
239
240 fn search_query(
241 &self,
242 with_document: bool,
243 req: &VectorSearchRequest<PgSearchFilter>,
244 ) -> (String, Vec<serde_json::Value>) {
245 let document = if with_document { ", document" } else { "" };
246
247 let thresh = req
248 .threshold()
249 .map(|t| PgSearchFilter::gt("distance".into(), t.into()));
250 let filter = match (thresh, req.filter()) {
251 (Some(thresh), Some(filt)) => Some(thresh.and(filt.clone())),
252 (Some(thresh), _) => Some(thresh),
253 (_, Some(filt)) => Some(filt.clone()),
254 _ => None,
255 };
256 let (where_clause, params) = match filter {
257 Some(f) => {
258 let (expr, params) = f.into_clause();
259 (String::from("WHERE") + &expr, params)
260 }
261 None => (Default::default(), Default::default()),
262 };
263
264 let mut counter = 3;
265 let mut buf = String::with_capacity(where_clause.len() * 2);
266
267 for c in where_clause.chars() {
268 buf.push(c);
269
270 if c == '$' {
271 buf.push_str(counter.to_string().as_str());
272 counter += 1;
273 }
274 }
275
276 let where_clause = buf;
277
278 let query = format!(
279 "
280 SELECT id{}, distance FROM ( \
281 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
282 FROM {} \
283 {where_clause} \
284 ORDER BY id, distance \
285 ) as d \
286 ORDER BY distance \
287 LIMIT $2",
288 document, document, self.distance_function, self.documents_table
289 );
290
291 (query, params)
292 }
293}
294
295impl<Model> InsertDocuments for PostgresVectorStore<Model>
296where
297 Model: EmbeddingModel + Send + Sync,
298{
299 async fn insert_documents<Doc: Serialize + Embed + Send>(
300 &self,
301 documents: Vec<(Doc, OneOrMany<Embedding>)>,
302 ) -> Result<(), VectorStoreError> {
303 for (document, embeddings) in documents {
304 let id = Uuid::new_v4();
305 let json_document = serde_json::to_value(&document).unwrap();
306
307 for embedding in embeddings {
308 let embedding_text = embedding.document;
309 let embedding: Vec<f64> = embedding.vec;
310
311 sqlx::query(
312 format!(
313 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
314 self.documents_table
315 )
316 .as_str(),
317 )
318 .bind(id)
319 .bind(&json_document)
320 .bind(&embedding_text)
321 .bind(&embedding)
322 .execute(&self.pg_pool)
323 .await
324 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
325 }
326 }
327
328 Ok(())
329 }
330}
331
332impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
333where
334 Model: EmbeddingModel,
335{
336 type Filter = PgSearchFilter;
337
338 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
341 &self,
342 req: VectorSearchRequest<PgSearchFilter>,
343 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
344 if req.samples() > i64::MAX as u64 {
345 return Err(VectorStoreError::DatastoreError(
346 format!(
347 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
348 i64::MAX
349 )
350 .into(),
351 ));
352 }
353
354 let embedded_query: pgvector::Vector = self
355 .model
356 .embed_text(req.query())
357 .await?
358 .vec
359 .iter()
360 .map(|&x| x as f32)
361 .collect::<Vec<f32>>()
362 .into();
363
364 let (search_query, params) = self.search_query_full(&req);
365 let builder = sqlx::query_as(search_query.as_str())
366 .bind(embedded_query)
367 .bind(req.samples() as i64);
368
369 let builder = params
370 .iter()
371 .fold(builder, |builder, param| builder.bind(param));
372
373 let rows = builder
374 .fetch_all(&self.pg_pool)
375 .await
376 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
377
378 let rows: Vec<(f64, String, T)> = rows
379 .into_iter()
380 .flat_map(SearchResult::into_result)
381 .collect();
382
383 Ok(rows)
384 }
385
386 async fn top_n_ids(
388 &self,
389 req: VectorSearchRequest<PgSearchFilter>,
390 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
391 if req.samples() > i64::MAX as u64 {
392 return Err(VectorStoreError::DatastoreError(
393 format!(
394 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
395 i64::MAX
396 )
397 .into(),
398 ));
399 }
400 let embedded_query: pgvector::Vector = self
401 .model
402 .embed_text(req.query())
403 .await?
404 .vec
405 .iter()
406 .map(|&x| x as f32)
407 .collect::<Vec<f32>>()
408 .into();
409
410 let (search_query, params) = self.search_query_only_ids(&req);
411 let builder = sqlx::query_as(search_query.as_str())
412 .bind(embedded_query)
413 .bind(req.samples() as i64);
414
415 let builder = params
416 .iter()
417 .fold(builder, |builder, param| builder.bind(param));
418
419 let rows: Vec<SearchResultOnlyId> = builder
420 .fetch_all(&self.pg_pool)
421 .await
422 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
423
424 let rows: Vec<(f64, String)> = rows
425 .into_iter()
426 .map(|row| (row.distance, row.id.to_string()))
427 .collect();
428
429 Ok(rows)
430 }
431}