Skip to main content

rig_postgres/
lib.rs

1//! PostgreSQL and pgvector integration for Rig.
2//!
3//! This crate provides [`PostgresVectorStore`], a Rig vector store backed by a
4//! PostgreSQL table with a `pgvector` embedding column. It supports the distance
5//! functions represented by [`PgVectorDistanceFunction`] and query filters via
6//! [`PgSearchFilter`].
7//!
8//! The root `rig` facade re-exports this crate as `rig::postgres` when the
9//! `postgres` feature is enabled.
10
11use std::{fmt::Display, ops::RangeInclusive};
12
13use rig_core::{
14    Embed, OneOrMany,
15    embeddings::{Embedding, EmbeddingModel},
16    vector_store::{
17        InsertDocuments, VectorStoreError, VectorStoreIndex,
18        request::{SearchFilter, VectorSearchRequest},
19    },
20};
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use serde_json::Value;
23use sqlx::{PgPool, Postgres, postgres::PgArguments, query::QueryAs};
24use uuid::Uuid;
25
26pub struct PostgresVectorStore<Model: EmbeddingModel> {
27    model: Model,
28    pg_pool: PgPool,
29    documents_table: String,
30    distance_function: PgVectorDistanceFunction,
31}
32
/// Distance operators supported by `pgvector`.
///
/// The [`Display`] implementation renders the SQL operator used in similarity queries:
///
/// | Variant        | Operator | Meaning                                        |
/// |----------------|----------|------------------------------------------------|
/// | `L2`           | `<->`    | L2 (Euclidean) distance                        |
/// | `InnerProduct` | `<#>`    | *negative* inner product                       |
/// | `Cosine`       | `<=>`    | cosine distance                                |
/// | `L1`           | `<+>`    | L1 distance (pgvector 0.7.0+)                  |
/// | `Hamming`      | `<~>`    | Hamming distance, binary vectors (0.7.0+)      |
/// | `Jaccard`      | `<%>`    | Jaccard distance, binary vectors (0.7.0+)      |
///
/// For every operator, smaller values mean closer vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgVectorDistanceFunction {
    L2,
    InnerProduct,
    Cosine,
    L1,
    Hamming,
    Jaccard,
}

impl Display for PgVectorDistanceFunction {
    /// Writes the `pgvector` SQL operator for this distance function.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let operator = match self {
            PgVectorDistanceFunction::L2 => "<->",
            PgVectorDistanceFunction::InnerProduct => "<#>",
            PgVectorDistanceFunction::Cosine => "<=>",
            PgVectorDistanceFunction::L1 => "<+>",
            PgVectorDistanceFunction::Hamming => "<~>",
            PgVectorDistanceFunction::Jaccard => "<%>",
        };
        f.write_str(operator)
    }
}
62
63#[derive(Clone, Default, Serialize, Deserialize, Debug)]
64pub struct PgSearchFilter {
65    condition: String,
66    values: Vec<serde_json::Value>,
67}
68
69impl SearchFilter for PgSearchFilter {
70    type Value = serde_json::Value;
71
72    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
73        Self {
74            condition: format!("{} = $", key.as_ref()),
75            values: vec![value],
76        }
77    }
78
79    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
80        Self {
81            condition: format!("{} > $", key.as_ref()),
82            values: vec![value],
83        }
84    }
85
86    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
87        Self {
88            condition: format!("{} < $", key.as_ref()),
89            values: vec![value],
90        }
91    }
92
93    fn and(self, rhs: Self) -> Self {
94        Self {
95            condition: format!("({}) AND ({})", self.condition, rhs.condition),
96            values: self.values.into_iter().chain(rhs.values).collect(),
97        }
98    }
99
100    fn or(self, rhs: Self) -> Self {
101        Self {
102            condition: format!("({}) OR ({})", self.condition, rhs.condition),
103            values: self.values.into_iter().chain(rhs.values).collect(),
104        }
105    }
106}
107
108impl PgSearchFilter {
109    fn into_clause(self) -> (String, Vec<serde_json::Value>) {
110        (self.condition, self.values)
111    }
112
113    #[allow(clippy::should_implement_trait)]
114    pub fn not(self) -> Self {
115        Self {
116            condition: format!("NOT ({})", self.condition),
117            values: self.values,
118        }
119    }
120
121    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
122        Self {
123            condition: format!("{key} >= ?"),
124            values: vec![value],
125        }
126    }
127
128    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
129        Self {
130            condition: format!("{key} <= ?"),
131            values: vec![value],
132        }
133    }
134
135    pub fn is_null(key: String) -> Self {
136        Self {
137            condition: format!("{key} is null"),
138            ..Default::default()
139        }
140    }
141
142    pub fn is_not_null(key: String) -> Self {
143        Self {
144            condition: format!("{key} is not null"),
145            ..Default::default()
146        }
147    }
148
149    pub fn between<T>(key: String, range: RangeInclusive<T>) -> Self
150    where
151        T: std::fmt::Display + Into<serde_json::Number> + Copy,
152    {
153        let lo = range.start();
154        let hi = range.end();
155
156        Self {
157            condition: format!("{key} between {lo} and {hi}"),
158            ..Default::default()
159        }
160    }
161
162    pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
163        let placeholders = values.iter().map(|_| "?").collect::<Vec<&str>>().join(",");
164
165        Self {
166            condition: format!("{key} is in ({placeholders})"),
167            values,
168        }
169    }
170
171    // String matching ops
172
173    /// Tests whether the value at `key` matches the (case-sensitive) pattern
174    /// `pattern` should be a valid SQL string pattern, with '%' and '_' as wildcards
175    pub fn like(key: String, pattern: &'static str) -> Self {
176        Self {
177            condition: format!("{key} like {pattern}"),
178            ..Default::default()
179        }
180    }
181
182    /// Tests whether the value at `key` matches the SQL regex pattern
183    /// `pattern` should be a valid regex
184    pub fn similar_to(key: String, pattern: &'static str) -> Self {
185        Self {
186            condition: format!("{key} similar to {pattern}"),
187            ..Default::default()
188        }
189    }
190}
191
192fn bind_value<S>(
193    builder: QueryAs<'_, Postgres, S, PgArguments>,
194    value: Value,
195) -> QueryAs<'_, Postgres, S, PgArguments> {
196    match value {
197        Value::Null => builder.bind(Option::<String>::None),
198        Value::Bool(b) => builder.bind(b),
199        Value::Number(num) => {
200            if let Some(n) = num.as_f64() {
201                builder.bind(n)
202            } else if let Some(n) = num.as_i64() {
203                builder.bind(n)
204            } else if let Some(n) = num.as_u64() {
205                builder.bind(n as i64)
206            } else {
207                builder.bind(num.to_string())
208            }
209        }
210        Value::String(s) => builder.bind(s),
211        Value::Array(xs) => {
212            if let Some(xs) = xs
213                .iter()
214                .map(|v| v.as_str().map(str::to_string))
215                .collect::<Option<Vec<_>>>()
216            {
217                builder.bind(xs)
218            } else if let Some(xs) = xs.iter().map(Value::as_f64).collect::<Option<Vec<_>>>() {
219                builder.bind(xs)
220            } else if let Some(xs) = xs.iter().map(Value::as_i64).collect::<Option<Vec<_>>>() {
221                builder.bind(xs)
222            } else if let Some(xs) = xs.iter().map(Value::as_bool).collect::<Option<Vec<_>>>() {
223                builder.bind(xs)
224            } else {
225                builder.bind(Value::Array(xs))
226            }
227        }
228        // Will always be JSONB
229        object => builder.bind(object),
230    }
231}
232
233#[derive(Debug, Deserialize, sqlx::FromRow)]
234pub struct SearchResult {
235    id: Uuid,
236    document: Value,
237    //embedded_text: String,
238    distance: f64,
239}
240
241#[derive(Debug, Deserialize, sqlx::FromRow)]
242pub struct SearchResultOnlyId {
243    id: Uuid,
244    distance: f64,
245}
246
247impl SearchResult {
248    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
249        let document: T =
250            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
251        Ok((self.distance, self.id.to_string(), document))
252    }
253}
254
255impl<Model> PostgresVectorStore<Model>
256where
257    Model: EmbeddingModel,
258{
259    pub fn new(
260        model: Model,
261        pg_pool: PgPool,
262        documents_table: Option<String>,
263        distance_function: PgVectorDistanceFunction,
264    ) -> Self {
265        Self {
266            model,
267            pg_pool,
268            documents_table: documents_table.unwrap_or(String::from("documents")),
269            distance_function,
270        }
271    }
272
273    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
274        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
275    }
276
277    fn search_query_full(
278        &self,
279        req: &VectorSearchRequest<PgSearchFilter>,
280    ) -> (String, Vec<serde_json::Value>) {
281        self.search_query(true, req)
282    }
283
284    fn search_query_only_ids(
285        &self,
286        req: &VectorSearchRequest<PgSearchFilter>,
287    ) -> (String, Vec<serde_json::Value>) {
288        self.search_query(false, req)
289    }
290
291    fn search_query(
292        &self,
293        with_document: bool,
294        req: &VectorSearchRequest<PgSearchFilter>,
295    ) -> (String, Vec<serde_json::Value>) {
296        let document = if with_document { ", document" } else { "" };
297
298        let thresh = req
299            .threshold()
300            .map(|t| PgSearchFilter::gt("distance", t.into()));
301        let filter = match (thresh, req.filter()) {
302            (Some(thresh), Some(filt)) => Some(thresh.and(filt.clone())),
303            (Some(thresh), _) => Some(thresh),
304            (_, Some(filt)) => Some(filt.clone()),
305            _ => None,
306        };
307        let (where_clause, params) = match filter {
308            Some(f) => {
309                let (expr, params) = f.into_clause();
310                (String::from("WHERE") + &expr, params)
311            }
312            None => (Default::default(), Default::default()),
313        };
314
315        let mut counter = 3;
316        let mut buf = String::with_capacity(where_clause.len() * 2);
317
318        for c in where_clause.chars() {
319            buf.push(c);
320
321            if c == '$' {
322                buf.push_str(counter.to_string().as_str());
323                counter += 1;
324            }
325        }
326
327        let where_clause = buf;
328
329        let query = format!(
330            "
331            SELECT id{}, distance FROM ( \
332              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
333              FROM {} \
334              {where_clause} \
335              ORDER BY id, distance \
336            ) as d \
337            ORDER BY distance \
338            LIMIT $2",
339            document, document, self.distance_function, self.documents_table
340        );
341
342        (query, params)
343    }
344}
345
346impl<Model> InsertDocuments for PostgresVectorStore<Model>
347where
348    Model: EmbeddingModel + Send + Sync,
349{
350    async fn insert_documents<Doc: Serialize + Embed + Send>(
351        &self,
352        documents: Vec<(Doc, OneOrMany<Embedding>)>,
353    ) -> Result<(), VectorStoreError> {
354        for (document, embeddings) in documents {
355            let id = Uuid::new_v4();
356            let json_document = serde_json::to_value(&document)?;
357
358            for embedding in embeddings {
359                let embedding_text = embedding.document;
360                let embedding: Vec<f64> = embedding.vec;
361
362                sqlx::query(
363                    format!(
364                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
365                        self.documents_table
366                    )
367                    .as_str(),
368                )
369                .bind(id)
370                .bind(&json_document)
371                .bind(&embedding_text)
372                .bind(&embedding)
373                .execute(&self.pg_pool)
374                .await
375                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
376            }
377        }
378
379        Ok(())
380    }
381}
382
383impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
384where
385    Model: EmbeddingModel,
386{
387    type Filter = PgSearchFilter;
388
389    /// Get the top n documents based on the distance to the given query.
390    /// The result is a list of tuples of the form (score, id, document)
391    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
392        &self,
393        req: VectorSearchRequest<PgSearchFilter>,
394    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
395        if req.samples() > i64::MAX as u64 {
396            return Err(VectorStoreError::DatastoreError(
397                format!(
398                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
399                    i64::MAX
400                )
401                .into(),
402            ));
403        }
404
405        let embedded_query: pgvector::Vector = self
406            .model
407            .embed_text(req.query())
408            .await?
409            .vec
410            .iter()
411            .map(|&x| x as f32)
412            .collect::<Vec<f32>>()
413            .into();
414
415        let (search_query, params) = self.search_query_full(&req);
416        let builder = sqlx::query_as(search_query.as_str())
417            .bind(embedded_query)
418            .bind(req.samples() as i64);
419
420        let builder = params.iter().cloned().fold(builder, bind_value);
421
422        let rows = builder
423            .fetch_all(&self.pg_pool)
424            .await
425            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
426
427        let rows: Vec<(f64, String, T)> = rows
428            .into_iter()
429            .flat_map(SearchResult::into_result)
430            .collect();
431
432        Ok(rows)
433    }
434
435    /// Same as `top_n` but returns the document ids only.
436    async fn top_n_ids(
437        &self,
438        req: VectorSearchRequest<PgSearchFilter>,
439    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
440        if req.samples() > i64::MAX as u64 {
441            return Err(VectorStoreError::DatastoreError(
442                format!(
443                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
444                    i64::MAX
445                )
446                .into(),
447            ));
448        }
449        let embedded_query: pgvector::Vector = self
450            .model
451            .embed_text(req.query())
452            .await?
453            .vec
454            .iter()
455            .map(|&x| x as f32)
456            .collect::<Vec<f32>>()
457            .into();
458
459        let (search_query, params) = self.search_query_only_ids(&req);
460        let builder = sqlx::query_as(search_query.as_str())
461            .bind(embedded_query)
462            .bind(req.samples() as i64);
463
464        let builder = params.iter().cloned().fold(builder, bind_value);
465
466        let rows: Vec<SearchResultOnlyId> = builder
467            .fetch_all(&self.pg_pool)
468            .await
469            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
470
471        let rows: Vec<(f64, String)> = rows
472            .into_iter()
473            .map(|row| (row.distance, row.id.to_string()))
474            .collect();
475
476        Ok(rows)
477    }
478}