1use std::{fmt::Display, ops::RangeInclusive};
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, VectorStoreError, VectorStoreIndex,
8 request::{SearchFilter, VectorSearchRequest},
9 },
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use serde_json::Value;
13use sqlx::{PgPool, Postgres, postgres::PgArguments, query::QueryAs};
14use uuid::Uuid;
15
16pub struct PostgresVectorStore<Model: EmbeddingModel> {
17 model: Model,
18 pg_pool: PgPool,
19 documents_table: String,
20 distance_function: PgVectorDistanceFunction,
21}
22
/// The pgvector distance operator used to rank rows in a similarity search.
///
/// Each variant is rendered to its SQL operator by the type's `Display` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// Euclidean (L2) distance, operator `<->`.
    L2,
    /// Inner product, operator `<#>` (pgvector documents this as the *negative* inner product).
    InnerProduct,
    /// Cosine distance, operator `<=>`.
    Cosine,
    /// Taxicab (L1) distance, operator `<+>`.
    L1,
    /// Hamming distance, operator `<~>`.
    // NOTE(review): pgvector defines this for binary vectors, while searches here bind a float
    // `pgvector::Vector` — confirm this variant is usable with the current query path.
    Hamming,
    /// Jaccard distance, operator `<%>`.
    // NOTE(review): same binary-vector caveat as `Hamming` — confirm.
    Jaccard,
}
39
40impl Display for PgVectorDistanceFunction {
41 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
42 match self {
43 PgVectorDistanceFunction::L2 => write!(f, "<->"),
44 PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
45 PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
46 PgVectorDistanceFunction::L1 => write!(f, "<+>"),
47 PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
48 PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
49 }
50 }
51}
52
53#[derive(Clone, Default, Serialize, Deserialize, Debug)]
54pub struct PgSearchFilter {
55 condition: String,
56 values: Vec<serde_json::Value>,
57}
58
59impl SearchFilter for PgSearchFilter {
60 type Value = serde_json::Value;
61
62 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
63 Self {
64 condition: format!("{} = $", key.as_ref()),
65 values: vec![value],
66 }
67 }
68
69 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
70 Self {
71 condition: format!("{} > $", key.as_ref()),
72 values: vec![value],
73 }
74 }
75
76 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
77 Self {
78 condition: format!("{} < $", key.as_ref()),
79 values: vec![value],
80 }
81 }
82
83 fn and(self, rhs: Self) -> Self {
84 Self {
85 condition: format!("({}) AND ({})", self.condition, rhs.condition),
86 values: self.values.into_iter().chain(rhs.values).collect(),
87 }
88 }
89
90 fn or(self, rhs: Self) -> Self {
91 Self {
92 condition: format!("({}) OR ({})", self.condition, rhs.condition),
93 values: self.values.into_iter().chain(rhs.values).collect(),
94 }
95 }
96}
97
98impl PgSearchFilter {
99 fn into_clause(self) -> (String, Vec<serde_json::Value>) {
100 (self.condition, self.values)
101 }
102
103 #[allow(clippy::should_implement_trait)]
104 pub fn not(self) -> Self {
105 Self {
106 condition: format!("NOT ({})", self.condition),
107 values: self.values,
108 }
109 }
110
111 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
112 Self {
113 condition: format!("{key} >= ?"),
114 values: vec![value],
115 }
116 }
117
118 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
119 Self {
120 condition: format!("{key} <= ?"),
121 values: vec![value],
122 }
123 }
124
125 pub fn is_null(key: String) -> Self {
126 Self {
127 condition: format!("{key} is null"),
128 ..Default::default()
129 }
130 }
131
132 pub fn is_not_null(key: String) -> Self {
133 Self {
134 condition: format!("{key} is not null"),
135 ..Default::default()
136 }
137 }
138
139 pub fn between<T>(key: String, range: RangeInclusive<T>) -> Self
140 where
141 T: std::fmt::Display + Into<serde_json::Number> + Copy,
142 {
143 let lo = range.start();
144 let hi = range.end();
145
146 Self {
147 condition: format!("{key} between {lo} and {hi}"),
148 ..Default::default()
149 }
150 }
151
152 pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153 let placeholders = values.iter().map(|_| "?").collect::<Vec<&str>>().join(",");
154
155 Self {
156 condition: format!("{key} is in ({placeholders})"),
157 values,
158 }
159 }
160
161 pub fn like(key: String, pattern: &'static str) -> Self {
166 Self {
167 condition: format!("{key} like {pattern}"),
168 ..Default::default()
169 }
170 }
171
172 pub fn similar_to(key: String, pattern: &'static str) -> Self {
175 Self {
176 condition: format!("{key} similar to {pattern}"),
177 ..Default::default()
178 }
179 }
180}
181
182fn bind_value<S>(
183 builder: QueryAs<'_, Postgres, S, PgArguments>,
184 value: Value,
185) -> QueryAs<'_, Postgres, S, PgArguments> {
186 match value {
187 Value::Null => builder.bind(Option::<String>::None),
188 Value::Bool(b) => builder.bind(b),
189 Value::Number(num) => {
190 if let Some(n) = num.as_f64() {
191 builder.bind(n)
192 } else if let Some(n) = num.as_i64() {
193 builder.bind(n)
194 } else if let Some(n) = num.as_u64() {
195 builder.bind(n as i64)
196 } else {
197 builder.bind(num.to_string())
198 }
199 }
200 Value::String(s) => builder.bind(s),
201 Value::Array(xs) => {
202 if let Some(xs) = xs
203 .iter()
204 .map(|v| v.as_str().map(str::to_string))
205 .collect::<Option<Vec<_>>>()
206 {
207 builder.bind(xs)
208 } else if let Some(xs) = xs.iter().map(Value::as_f64).collect::<Option<Vec<_>>>() {
209 builder.bind(xs)
210 } else if let Some(xs) = xs.iter().map(Value::as_i64).collect::<Option<Vec<_>>>() {
211 builder.bind(xs)
212 } else if let Some(xs) = xs.iter().map(Value::as_bool).collect::<Option<Vec<_>>>() {
213 builder.bind(xs)
214 } else {
215 builder.bind(Value::Array(xs))
216 }
217 }
218 object => builder.bind(object),
220 }
221}
222
223#[derive(Debug, Deserialize, sqlx::FromRow)]
224pub struct SearchResult {
225 id: Uuid,
226 document: Value,
227 distance: f64,
229}
230
231#[derive(Debug, Deserialize, sqlx::FromRow)]
232pub struct SearchResultOnlyId {
233 id: Uuid,
234 distance: f64,
235}
236
237impl SearchResult {
238 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
239 let document: T =
240 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
241 Ok((self.distance, self.id.to_string(), document))
242 }
243}
244
245impl<Model> PostgresVectorStore<Model>
246where
247 Model: EmbeddingModel,
248{
249 pub fn new(
250 model: Model,
251 pg_pool: PgPool,
252 documents_table: Option<String>,
253 distance_function: PgVectorDistanceFunction,
254 ) -> Self {
255 Self {
256 model,
257 pg_pool,
258 documents_table: documents_table.unwrap_or(String::from("documents")),
259 distance_function,
260 }
261 }
262
263 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
264 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
265 }
266
267 fn search_query_full(
268 &self,
269 req: &VectorSearchRequest<PgSearchFilter>,
270 ) -> (String, Vec<serde_json::Value>) {
271 self.search_query(true, req)
272 }
273
274 fn search_query_only_ids(
275 &self,
276 req: &VectorSearchRequest<PgSearchFilter>,
277 ) -> (String, Vec<serde_json::Value>) {
278 self.search_query(false, req)
279 }
280
281 fn search_query(
282 &self,
283 with_document: bool,
284 req: &VectorSearchRequest<PgSearchFilter>,
285 ) -> (String, Vec<serde_json::Value>) {
286 let document = if with_document { ", document" } else { "" };
287
288 let thresh = req
289 .threshold()
290 .map(|t| PgSearchFilter::gt("distance", t.into()));
291 let filter = match (thresh, req.filter()) {
292 (Some(thresh), Some(filt)) => Some(thresh.and(filt.clone())),
293 (Some(thresh), _) => Some(thresh),
294 (_, Some(filt)) => Some(filt.clone()),
295 _ => None,
296 };
297 let (where_clause, params) = match filter {
298 Some(f) => {
299 let (expr, params) = f.into_clause();
300 (String::from("WHERE") + &expr, params)
301 }
302 None => (Default::default(), Default::default()),
303 };
304
305 let mut counter = 3;
306 let mut buf = String::with_capacity(where_clause.len() * 2);
307
308 for c in where_clause.chars() {
309 buf.push(c);
310
311 if c == '$' {
312 buf.push_str(counter.to_string().as_str());
313 counter += 1;
314 }
315 }
316
317 let where_clause = buf;
318
319 let query = format!(
320 "
321 SELECT id{}, distance FROM ( \
322 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
323 FROM {} \
324 {where_clause} \
325 ORDER BY id, distance \
326 ) as d \
327 ORDER BY distance \
328 LIMIT $2",
329 document, document, self.distance_function, self.documents_table
330 );
331
332 (query, params)
333 }
334}
335
336impl<Model> InsertDocuments for PostgresVectorStore<Model>
337where
338 Model: EmbeddingModel + Send + Sync,
339{
340 async fn insert_documents<Doc: Serialize + Embed + Send>(
341 &self,
342 documents: Vec<(Doc, OneOrMany<Embedding>)>,
343 ) -> Result<(), VectorStoreError> {
344 for (document, embeddings) in documents {
345 let id = Uuid::new_v4();
346 let json_document = serde_json::to_value(&document)?;
347
348 for embedding in embeddings {
349 let embedding_text = embedding.document;
350 let embedding: Vec<f64> = embedding.vec;
351
352 sqlx::query(
353 format!(
354 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
355 self.documents_table
356 )
357 .as_str(),
358 )
359 .bind(id)
360 .bind(&json_document)
361 .bind(&embedding_text)
362 .bind(&embedding)
363 .execute(&self.pg_pool)
364 .await
365 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
366 }
367 }
368
369 Ok(())
370 }
371}
372
373impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
374where
375 Model: EmbeddingModel,
376{
377 type Filter = PgSearchFilter;
378
379 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
382 &self,
383 req: VectorSearchRequest<PgSearchFilter>,
384 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
385 if req.samples() > i64::MAX as u64 {
386 return Err(VectorStoreError::DatastoreError(
387 format!(
388 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
389 i64::MAX
390 )
391 .into(),
392 ));
393 }
394
395 let embedded_query: pgvector::Vector = self
396 .model
397 .embed_text(req.query())
398 .await?
399 .vec
400 .iter()
401 .map(|&x| x as f32)
402 .collect::<Vec<f32>>()
403 .into();
404
405 let (search_query, params) = self.search_query_full(&req);
406 let builder = sqlx::query_as(search_query.as_str())
407 .bind(embedded_query)
408 .bind(req.samples() as i64);
409
410 let builder = params.iter().cloned().fold(builder, bind_value);
411
412 let rows = builder
413 .fetch_all(&self.pg_pool)
414 .await
415 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
416
417 let rows: Vec<(f64, String, T)> = rows
418 .into_iter()
419 .flat_map(SearchResult::into_result)
420 .collect();
421
422 Ok(rows)
423 }
424
425 async fn top_n_ids(
427 &self,
428 req: VectorSearchRequest<PgSearchFilter>,
429 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
430 if req.samples() > i64::MAX as u64 {
431 return Err(VectorStoreError::DatastoreError(
432 format!(
433 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
434 i64::MAX
435 )
436 .into(),
437 ));
438 }
439 let embedded_query: pgvector::Vector = self
440 .model
441 .embed_text(req.query())
442 .await?
443 .vec
444 .iter()
445 .map(|&x| x as f32)
446 .collect::<Vec<f32>>()
447 .into();
448
449 let (search_query, params) = self.search_query_only_ids(&req);
450 let builder = sqlx::query_as(search_query.as_str())
451 .bind(embedded_query)
452 .bind(req.samples() as i64);
453
454 let builder = params.iter().cloned().fold(builder, bind_value);
455
456 let rows: Vec<SearchResultOnlyId> = builder
457 .fetch_all(&self.pg_pool)
458 .await
459 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
460
461 let rows: Vec<(f64, String)> = rows
462 .into_iter()
463 .map(|row| (row.distance, row.id.to_string()))
464 .collect();
465
466 Ok(rows)
467 }
468}