1use std::{fmt::Display, ops::RangeInclusive};
12
13use rig_core::{
14 Embed, OneOrMany,
15 embeddings::{Embedding, EmbeddingModel},
16 vector_store::{
17 InsertDocuments, VectorStoreError, VectorStoreIndex,
18 request::{SearchFilter, VectorSearchRequest},
19 },
20};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use serde_json::Value;
23use sqlx::{PgPool, Postgres, postgres::PgArguments, query::QueryAs};
24use uuid::Uuid;
25
26pub struct PostgresVectorStore<Model: EmbeddingModel> {
27 model: Model,
28 pg_pool: PgPool,
29 documents_table: String,
30 distance_function: PgVectorDistanceFunction,
31}
32
/// Distance operators that a search query can rank rows with.
///
/// Each variant is rendered as a SQL operator by this type's `Display`
/// implementation; the operator names below follow the pgvector extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgVectorDistanceFunction {
    /// `<->` — L2 (Euclidean) distance.
    L2,
    /// `<#>` — inner product (pgvector returns the negative inner product).
    InnerProduct,
    /// `<=>` — cosine distance.
    Cosine,
    /// `<+>` — L1 distance.
    L1,
    /// `<~>` — Hamming distance (binary vectors).
    Hamming,
    /// `<%>` — Jaccard distance (binary vectors).
    Jaccard,
}
49
50impl Display for PgVectorDistanceFunction {
51 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
52 match self {
53 PgVectorDistanceFunction::L2 => write!(f, "<->"),
54 PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
55 PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
56 PgVectorDistanceFunction::L1 => write!(f, "<+>"),
57 PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
58 PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
59 }
60 }
61}
62
63#[derive(Clone, Default, Serialize, Deserialize, Debug)]
64pub struct PgSearchFilter {
65 condition: String,
66 values: Vec<serde_json::Value>,
67}
68
69impl SearchFilter for PgSearchFilter {
70 type Value = serde_json::Value;
71
72 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
73 Self {
74 condition: format!("{} = $", key.as_ref()),
75 values: vec![value],
76 }
77 }
78
79 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
80 Self {
81 condition: format!("{} > $", key.as_ref()),
82 values: vec![value],
83 }
84 }
85
86 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
87 Self {
88 condition: format!("{} < $", key.as_ref()),
89 values: vec![value],
90 }
91 }
92
93 fn and(self, rhs: Self) -> Self {
94 Self {
95 condition: format!("({}) AND ({})", self.condition, rhs.condition),
96 values: self.values.into_iter().chain(rhs.values).collect(),
97 }
98 }
99
100 fn or(self, rhs: Self) -> Self {
101 Self {
102 condition: format!("({}) OR ({})", self.condition, rhs.condition),
103 values: self.values.into_iter().chain(rhs.values).collect(),
104 }
105 }
106}
107
108impl PgSearchFilter {
109 fn into_clause(self) -> (String, Vec<serde_json::Value>) {
110 (self.condition, self.values)
111 }
112
113 #[allow(clippy::should_implement_trait)]
114 pub fn not(self) -> Self {
115 Self {
116 condition: format!("NOT ({})", self.condition),
117 values: self.values,
118 }
119 }
120
121 pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
122 Self {
123 condition: format!("{key} >= ?"),
124 values: vec![value],
125 }
126 }
127
128 pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
129 Self {
130 condition: format!("{key} <= ?"),
131 values: vec![value],
132 }
133 }
134
135 pub fn is_null(key: String) -> Self {
136 Self {
137 condition: format!("{key} is null"),
138 ..Default::default()
139 }
140 }
141
142 pub fn is_not_null(key: String) -> Self {
143 Self {
144 condition: format!("{key} is not null"),
145 ..Default::default()
146 }
147 }
148
149 pub fn between<T>(key: String, range: RangeInclusive<T>) -> Self
150 where
151 T: std::fmt::Display + Into<serde_json::Number> + Copy,
152 {
153 let lo = range.start();
154 let hi = range.end();
155
156 Self {
157 condition: format!("{key} between {lo} and {hi}"),
158 ..Default::default()
159 }
160 }
161
162 pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
163 let placeholders = values.iter().map(|_| "?").collect::<Vec<&str>>().join(",");
164
165 Self {
166 condition: format!("{key} is in ({placeholders})"),
167 values,
168 }
169 }
170
171 pub fn like(key: String, pattern: &'static str) -> Self {
176 Self {
177 condition: format!("{key} like {pattern}"),
178 ..Default::default()
179 }
180 }
181
182 pub fn similar_to(key: String, pattern: &'static str) -> Self {
185 Self {
186 condition: format!("{key} similar to {pattern}"),
187 ..Default::default()
188 }
189 }
190}
191
192fn bind_value<S>(
193 builder: QueryAs<'_, Postgres, S, PgArguments>,
194 value: Value,
195) -> QueryAs<'_, Postgres, S, PgArguments> {
196 match value {
197 Value::Null => builder.bind(Option::<String>::None),
198 Value::Bool(b) => builder.bind(b),
199 Value::Number(num) => {
200 if let Some(n) = num.as_f64() {
201 builder.bind(n)
202 } else if let Some(n) = num.as_i64() {
203 builder.bind(n)
204 } else if let Some(n) = num.as_u64() {
205 builder.bind(n as i64)
206 } else {
207 builder.bind(num.to_string())
208 }
209 }
210 Value::String(s) => builder.bind(s),
211 Value::Array(xs) => {
212 if let Some(xs) = xs
213 .iter()
214 .map(|v| v.as_str().map(str::to_string))
215 .collect::<Option<Vec<_>>>()
216 {
217 builder.bind(xs)
218 } else if let Some(xs) = xs.iter().map(Value::as_f64).collect::<Option<Vec<_>>>() {
219 builder.bind(xs)
220 } else if let Some(xs) = xs.iter().map(Value::as_i64).collect::<Option<Vec<_>>>() {
221 builder.bind(xs)
222 } else if let Some(xs) = xs.iter().map(Value::as_bool).collect::<Option<Vec<_>>>() {
223 builder.bind(xs)
224 } else {
225 builder.bind(Value::Array(xs))
226 }
227 }
228 object => builder.bind(object),
230 }
231}
232
233#[derive(Debug, Deserialize, sqlx::FromRow)]
234pub struct SearchResult {
235 id: Uuid,
236 document: Value,
237 distance: f64,
239}
240
241#[derive(Debug, Deserialize, sqlx::FromRow)]
242pub struct SearchResultOnlyId {
243 id: Uuid,
244 distance: f64,
245}
246
247impl SearchResult {
248 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
249 let document: T =
250 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
251 Ok((self.distance, self.id.to_string(), document))
252 }
253}
254
255impl<Model> PostgresVectorStore<Model>
256where
257 Model: EmbeddingModel,
258{
259 pub fn new(
260 model: Model,
261 pg_pool: PgPool,
262 documents_table: Option<String>,
263 distance_function: PgVectorDistanceFunction,
264 ) -> Self {
265 Self {
266 model,
267 pg_pool,
268 documents_table: documents_table.unwrap_or(String::from("documents")),
269 distance_function,
270 }
271 }
272
273 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
274 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
275 }
276
277 fn search_query_full(
278 &self,
279 req: &VectorSearchRequest<PgSearchFilter>,
280 ) -> (String, Vec<serde_json::Value>) {
281 self.search_query(true, req)
282 }
283
284 fn search_query_only_ids(
285 &self,
286 req: &VectorSearchRequest<PgSearchFilter>,
287 ) -> (String, Vec<serde_json::Value>) {
288 self.search_query(false, req)
289 }
290
291 fn search_query(
292 &self,
293 with_document: bool,
294 req: &VectorSearchRequest<PgSearchFilter>,
295 ) -> (String, Vec<serde_json::Value>) {
296 let document = if with_document { ", document" } else { "" };
297
298 let thresh = req
299 .threshold()
300 .map(|t| PgSearchFilter::gt("distance", t.into()));
301 let filter = match (thresh, req.filter()) {
302 (Some(thresh), Some(filt)) => Some(thresh.and(filt.clone())),
303 (Some(thresh), _) => Some(thresh),
304 (_, Some(filt)) => Some(filt.clone()),
305 _ => None,
306 };
307 let (where_clause, params) = match filter {
308 Some(f) => {
309 let (expr, params) = f.into_clause();
310 (String::from("WHERE") + &expr, params)
311 }
312 None => (Default::default(), Default::default()),
313 };
314
315 let mut counter = 3;
316 let mut buf = String::with_capacity(where_clause.len() * 2);
317
318 for c in where_clause.chars() {
319 buf.push(c);
320
321 if c == '$' {
322 buf.push_str(counter.to_string().as_str());
323 counter += 1;
324 }
325 }
326
327 let where_clause = buf;
328
329 let query = format!(
330 "
331 SELECT id{}, distance FROM ( \
332 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
333 FROM {} \
334 {where_clause} \
335 ORDER BY id, distance \
336 ) as d \
337 ORDER BY distance \
338 LIMIT $2",
339 document, document, self.distance_function, self.documents_table
340 );
341
342 (query, params)
343 }
344}
345
346impl<Model> InsertDocuments for PostgresVectorStore<Model>
347where
348 Model: EmbeddingModel + Send + Sync,
349{
350 async fn insert_documents<Doc: Serialize + Embed + Send>(
351 &self,
352 documents: Vec<(Doc, OneOrMany<Embedding>)>,
353 ) -> Result<(), VectorStoreError> {
354 for (document, embeddings) in documents {
355 let id = Uuid::new_v4();
356 let json_document = serde_json::to_value(&document)?;
357
358 for embedding in embeddings {
359 let embedding_text = embedding.document;
360 let embedding: Vec<f64> = embedding.vec;
361
362 sqlx::query(
363 format!(
364 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
365 self.documents_table
366 )
367 .as_str(),
368 )
369 .bind(id)
370 .bind(&json_document)
371 .bind(&embedding_text)
372 .bind(&embedding)
373 .execute(&self.pg_pool)
374 .await
375 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
376 }
377 }
378
379 Ok(())
380 }
381}
382
383impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
384where
385 Model: EmbeddingModel,
386{
387 type Filter = PgSearchFilter;
388
389 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
392 &self,
393 req: VectorSearchRequest<PgSearchFilter>,
394 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
395 if req.samples() > i64::MAX as u64 {
396 return Err(VectorStoreError::DatastoreError(
397 format!(
398 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
399 i64::MAX
400 )
401 .into(),
402 ));
403 }
404
405 let embedded_query: pgvector::Vector = self
406 .model
407 .embed_text(req.query())
408 .await?
409 .vec
410 .iter()
411 .map(|&x| x as f32)
412 .collect::<Vec<f32>>()
413 .into();
414
415 let (search_query, params) = self.search_query_full(&req);
416 let builder = sqlx::query_as(search_query.as_str())
417 .bind(embedded_query)
418 .bind(req.samples() as i64);
419
420 let builder = params.iter().cloned().fold(builder, bind_value);
421
422 let rows = builder
423 .fetch_all(&self.pg_pool)
424 .await
425 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
426
427 let rows: Vec<(f64, String, T)> = rows
428 .into_iter()
429 .flat_map(SearchResult::into_result)
430 .collect();
431
432 Ok(rows)
433 }
434
435 async fn top_n_ids(
437 &self,
438 req: VectorSearchRequest<PgSearchFilter>,
439 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
440 if req.samples() > i64::MAX as u64 {
441 return Err(VectorStoreError::DatastoreError(
442 format!(
443 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
444 i64::MAX
445 )
446 .into(),
447 ));
448 }
449 let embedded_query: pgvector::Vector = self
450 .model
451 .embed_text(req.query())
452 .await?
453 .vec
454 .iter()
455 .map(|&x| x as f32)
456 .collect::<Vec<f32>>()
457 .into();
458
459 let (search_query, params) = self.search_query_only_ids(&req);
460 let builder = sqlx::query_as(search_query.as_str())
461 .bind(embedded_query)
462 .bind(req.samples() as i64);
463
464 let builder = params.iter().cloned().fold(builder, bind_value);
465
466 let rows: Vec<SearchResultOnlyId> = builder
467 .fetch_all(&self.pg_pool)
468 .await
469 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
470
471 let rows: Vec<(f64, String)> = rows
472 .into_iter()
473 .map(|row| (row.distance, row.id.to_string()))
474 .collect();
475
476 Ok(rows)
477 }
478}