1use std::{fmt::Display, ops::RangeInclusive};
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, VectorStoreError, VectorStoreIndex,
8 request::{SearchFilter, VectorSearchRequest},
9 },
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use serde_json::Value;
13use sqlx::{PgPool, Postgres, postgres::PgArguments, query::QueryAs};
14use uuid::Uuid;
15
16pub struct PostgresVectorStore<Model: EmbeddingModel> {
17 model: Model,
18 pg_pool: PgPool,
19 documents_table: String,
20 distance_function: PgVectorDistanceFunction,
21}
22
/// Distance operator used to rank rows against the query vector.
///
/// Each variant is rendered by its `Display` implementation as the SQL
/// operator that is spliced into the search query.
///
/// The type is a plain field-less enum, so it derives the usual value-type
/// traits: it can be debug-printed, copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// Rendered as `<->`.
    L2,
    /// Rendered as `<#>`.
    InnerProduct,
    /// Rendered as `<=>`.
    Cosine,
    /// Rendered as `<+>`.
    L1,
    /// Rendered as `<~>`.
    Hamming,
    /// Rendered as `<%>`.
    Jaccard,
}
39
40impl Display for PgVectorDistanceFunction {
41 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
42 match self {
43 PgVectorDistanceFunction::L2 => write!(f, "<->"),
44 PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
45 PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
46 PgVectorDistanceFunction::L1 => write!(f, "<+>"),
47 PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
48 PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
49 }
50 }
51}
52
53#[derive(Clone, Default, Serialize, Deserialize, Debug)]
54pub struct PgSearchFilter {
55 condition: String,
56 values: Vec<serde_json::Value>,
57}
58
59impl SearchFilter for PgSearchFilter {
60 type Value = serde_json::Value;
61
62 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
63 Self {
64 condition: format!("{} = $", key.as_ref()),
65 values: vec![value],
66 }
67 }
68
69 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
70 Self {
71 condition: format!("{} > $", key.as_ref()),
72 values: vec![value],
73 }
74 }
75
76 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
77 Self {
78 condition: format!("{} < $", key.as_ref()),
79 values: vec![value],
80 }
81 }
82
83 fn and(self, rhs: Self) -> Self {
84 Self {
85 condition: format!("({}) AND ({})", self.condition, rhs.condition),
86 values: self.values.into_iter().chain(rhs.values).collect(),
87 }
88 }
89
90 fn or(self, rhs: Self) -> Self {
91 Self {
92 condition: format!("({}) OR ({})", self.condition, rhs.condition),
93 values: self.values.into_iter().chain(rhs.values).collect(),
94 }
95 }
96}
97
98impl PgSearchFilter {
99 fn into_clause(self) -> (String, Vec<serde_json::Value>) {
100 (self.condition, self.values)
101 }
102
103 #[allow(clippy::should_implement_trait)]
104 pub fn not(self) -> Self {
105 Self {
106 condition: format!("NOT ({})", self.condition),
107 values: self.values,
108 }
109 }
110
111 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
112 Self {
113 condition: format!("{key} >= ?"),
114 values: vec![value],
115 }
116 }
117
118 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
119 Self {
120 condition: format!("{key} <= ?"),
121 values: vec![value],
122 }
123 }
124
125 pub fn is_null(key: String) -> Self {
126 Self {
127 condition: format!("{key} is null"),
128 ..Default::default()
129 }
130 }
131
132 pub fn is_not_null(key: String) -> Self {
133 Self {
134 condition: format!("{key} is not null"),
135 ..Default::default()
136 }
137 }
138
139 pub fn between<T>(key: String, range: RangeInclusive<T>) -> Self
140 where
141 T: std::fmt::Display + Into<serde_json::Number> + Copy,
142 {
143 let lo = range.start();
144 let hi = range.end();
145
146 Self {
147 condition: format!("{key} between {lo} and {hi}"),
148 ..Default::default()
149 }
150 }
151
152 pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153 let placeholders = values.iter().map(|_| "?").collect::<Vec<&str>>().join(",");
154
155 Self {
156 condition: format!("{key} is in ({placeholders})"),
157 values,
158 }
159 }
160
161 pub fn like(key: String, pattern: &'static str) -> Self {
166 Self {
167 condition: format!("{key} like {pattern}"),
168 ..Default::default()
169 }
170 }
171
172 pub fn similar_to(key: String, pattern: &'static str) -> Self {
175 Self {
176 condition: format!("{key} similar to {pattern}"),
177 ..Default::default()
178 }
179 }
180}
181
182fn bind_value<S>(
183 builder: QueryAs<'_, Postgres, S, PgArguments>,
184 value: Value,
185) -> QueryAs<'_, Postgres, S, PgArguments> {
186 match value {
187 Value::Null => unreachable!(),
188 Value::Bool(b) => builder.bind(b),
189 Value::Number(num) => {
190 if let Some(n) = num.as_f64() {
191 builder.bind(n)
192 } else if let Some(n) = num.as_i64() {
193 builder.bind(n)
194 } else {
195 unreachable!()
196 }
197 }
198 Value::String(s) => builder.bind(s),
199 Value::Array(xs) => {
200 if let Some(xs) = xs
201 .iter()
202 .map(|v| v.as_str().map(str::to_string))
203 .collect::<Option<Vec<_>>>()
204 {
205 builder.bind(xs)
206 } else if let Some(xs) = xs.iter().map(Value::as_f64).collect::<Option<Vec<_>>>() {
207 builder.bind(xs)
208 } else if let Some(xs) = xs.iter().map(Value::as_i64).collect::<Option<Vec<_>>>() {
209 builder.bind(xs)
210 } else if let Some(xs) = xs.iter().map(Value::as_bool).collect::<Option<Vec<_>>>() {
211 builder.bind(xs)
212 } else {
213 builder.bind(Value::Array(xs))
214 }
215 }
216 object => builder.bind(object),
218 }
219}
220
221#[derive(Debug, Deserialize, sqlx::FromRow)]
222pub struct SearchResult {
223 id: Uuid,
224 document: Value,
225 distance: f64,
227}
228
229#[derive(Debug, Deserialize, sqlx::FromRow)]
230pub struct SearchResultOnlyId {
231 id: Uuid,
232 distance: f64,
233}
234
235impl SearchResult {
236 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
237 let document: T =
238 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
239 Ok((self.distance, self.id.to_string(), document))
240 }
241}
242
243impl<Model> PostgresVectorStore<Model>
244where
245 Model: EmbeddingModel,
246{
247 pub fn new(
248 model: Model,
249 pg_pool: PgPool,
250 documents_table: Option<String>,
251 distance_function: PgVectorDistanceFunction,
252 ) -> Self {
253 Self {
254 model,
255 pg_pool,
256 documents_table: documents_table.unwrap_or(String::from("documents")),
257 distance_function,
258 }
259 }
260
261 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
262 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
263 }
264
265 fn search_query_full(
266 &self,
267 req: &VectorSearchRequest<PgSearchFilter>,
268 ) -> (String, Vec<serde_json::Value>) {
269 self.search_query(true, req)
270 }
271
272 fn search_query_only_ids(
273 &self,
274 req: &VectorSearchRequest<PgSearchFilter>,
275 ) -> (String, Vec<serde_json::Value>) {
276 self.search_query(false, req)
277 }
278
279 fn search_query(
280 &self,
281 with_document: bool,
282 req: &VectorSearchRequest<PgSearchFilter>,
283 ) -> (String, Vec<serde_json::Value>) {
284 let document = if with_document { ", document" } else { "" };
285
286 let thresh = req
287 .threshold()
288 .map(|t| PgSearchFilter::gt("distance", t.into()));
289 let filter = match (thresh, req.filter()) {
290 (Some(thresh), Some(filt)) => Some(thresh.and(filt.clone())),
291 (Some(thresh), _) => Some(thresh),
292 (_, Some(filt)) => Some(filt.clone()),
293 _ => None,
294 };
295 let (where_clause, params) = match filter {
296 Some(f) => {
297 let (expr, params) = f.into_clause();
298 (String::from("WHERE") + &expr, params)
299 }
300 None => (Default::default(), Default::default()),
301 };
302
303 let mut counter = 3;
304 let mut buf = String::with_capacity(where_clause.len() * 2);
305
306 for c in where_clause.chars() {
307 buf.push(c);
308
309 if c == '$' {
310 buf.push_str(counter.to_string().as_str());
311 counter += 1;
312 }
313 }
314
315 let where_clause = buf;
316
317 let query = format!(
318 "
319 SELECT id{}, distance FROM ( \
320 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
321 FROM {} \
322 {where_clause} \
323 ORDER BY id, distance \
324 ) as d \
325 ORDER BY distance \
326 LIMIT $2",
327 document, document, self.distance_function, self.documents_table
328 );
329
330 (query, params)
331 }
332}
333
334impl<Model> InsertDocuments for PostgresVectorStore<Model>
335where
336 Model: EmbeddingModel + Send + Sync,
337{
338 async fn insert_documents<Doc: Serialize + Embed + Send>(
339 &self,
340 documents: Vec<(Doc, OneOrMany<Embedding>)>,
341 ) -> Result<(), VectorStoreError> {
342 for (document, embeddings) in documents {
343 let id = Uuid::new_v4();
344 let json_document = serde_json::to_value(&document).unwrap();
345
346 for embedding in embeddings {
347 let embedding_text = embedding.document;
348 let embedding: Vec<f64> = embedding.vec;
349
350 sqlx::query(
351 format!(
352 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
353 self.documents_table
354 )
355 .as_str(),
356 )
357 .bind(id)
358 .bind(&json_document)
359 .bind(&embedding_text)
360 .bind(&embedding)
361 .execute(&self.pg_pool)
362 .await
363 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
364 }
365 }
366
367 Ok(())
368 }
369}
370
371impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
372where
373 Model: EmbeddingModel,
374{
375 type Filter = PgSearchFilter;
376
377 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
380 &self,
381 req: VectorSearchRequest<PgSearchFilter>,
382 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
383 if req.samples() > i64::MAX as u64 {
384 return Err(VectorStoreError::DatastoreError(
385 format!(
386 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
387 i64::MAX
388 )
389 .into(),
390 ));
391 }
392
393 let embedded_query: pgvector::Vector = self
394 .model
395 .embed_text(req.query())
396 .await?
397 .vec
398 .iter()
399 .map(|&x| x as f32)
400 .collect::<Vec<f32>>()
401 .into();
402
403 let (search_query, params) = self.search_query_full(&req);
404 let builder = sqlx::query_as(search_query.as_str())
405 .bind(embedded_query)
406 .bind(req.samples() as i64);
407
408 let builder = params.iter().cloned().fold(builder, bind_value);
409
410 let rows = builder
411 .fetch_all(&self.pg_pool)
412 .await
413 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
414
415 let rows: Vec<(f64, String, T)> = rows
416 .into_iter()
417 .flat_map(SearchResult::into_result)
418 .collect();
419
420 Ok(rows)
421 }
422
423 async fn top_n_ids(
425 &self,
426 req: VectorSearchRequest<PgSearchFilter>,
427 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
428 if req.samples() > i64::MAX as u64 {
429 return Err(VectorStoreError::DatastoreError(
430 format!(
431 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
432 i64::MAX
433 )
434 .into(),
435 ));
436 }
437 let embedded_query: pgvector::Vector = self
438 .model
439 .embed_text(req.query())
440 .await?
441 .vec
442 .iter()
443 .map(|&x| x as f32)
444 .collect::<Vec<f32>>()
445 .into();
446
447 let (search_query, params) = self.search_query_only_ids(&req);
448 let builder = sqlx::query_as(search_query.as_str())
449 .bind(embedded_query)
450 .bind(req.samples() as i64);
451
452 let builder = params.iter().cloned().fold(builder, bind_value);
453
454 let rows: Vec<SearchResultOnlyId> = builder
455 .fetch_all(&self.pg_pool)
456 .await
457 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
458
459 let rows: Vec<(f64, String)> = rows
460 .into_iter()
461 .map(|row| (row.distance, row.id.to_string()))
462 .collect();
463
464 Ok(rows)
465 }
466}