Skip to main content

rig_postgres/
lib.rs

1use std::{fmt::Display, ops::RangeInclusive};
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, VectorStoreError, VectorStoreIndex,
8        request::{SearchFilter, VectorSearchRequest},
9    },
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use serde_json::Value;
13use sqlx::{PgPool, Postgres, postgres::PgArguments, query::QueryAs};
14use uuid::Uuid;
15
16pub struct PostgresVectorStore<Model: EmbeddingModel> {
17    model: Model,
18    pg_pool: PgPool,
19    documents_table: String,
20    distance_function: PgVectorDistanceFunction,
21}
22
/// Distance operators provided by the pgvector extension.
///
/// The [`Display`] implementation renders the SQL operator token that is spliced into the
/// similarity query.
///
/// `L1`, `Hamming` and `Jaccard` were added in pgvector 0.7.0. `Hamming` and `Jaccard` operate on
/// binary vectors only.
// NOTE(review): search queries bind the query embedding as a float `pgvector::Vector`, so
// `Hamming` and `Jaccard` presumably fail at query time against this store — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgVectorDistanceFunction {
    /// `<->`: L2 (Euclidean) distance.
    L2,
    /// `<#>`: negative inner product.
    InnerProduct,
    /// `<=>`: cosine distance.
    Cosine,
    /// `<+>`: L1 distance (pgvector 0.7.0+).
    L1,
    /// `<~>`: Hamming distance, binary vectors only (pgvector 0.7.0+).
    Hamming,
    /// `<%>`: Jaccard distance, binary vectors only (pgvector 0.7.0+).
    Jaccard,
}

impl Display for PgVectorDistanceFunction {
    /// Writes the pgvector SQL operator for this distance function.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let operator = match self {
            PgVectorDistanceFunction::L2 => "<->",
            PgVectorDistanceFunction::InnerProduct => "<#>",
            PgVectorDistanceFunction::Cosine => "<=>",
            PgVectorDistanceFunction::L1 => "<+>",
            PgVectorDistanceFunction::Hamming => "<~>",
            PgVectorDistanceFunction::Jaccard => "<%>",
        };
        f.write_str(operator)
    }
}
52
53#[derive(Clone, Default, Serialize, Deserialize, Debug)]
54pub struct PgSearchFilter {
55    condition: String,
56    values: Vec<serde_json::Value>,
57}
58
59impl SearchFilter for PgSearchFilter {
60    type Value = serde_json::Value;
61
62    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
63        Self {
64            condition: format!("{} = $", key.as_ref()),
65            values: vec![value],
66        }
67    }
68
69    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
70        Self {
71            condition: format!("{} > $", key.as_ref()),
72            values: vec![value],
73        }
74    }
75
76    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
77        Self {
78            condition: format!("{} < $", key.as_ref()),
79            values: vec![value],
80        }
81    }
82
83    fn and(self, rhs: Self) -> Self {
84        Self {
85            condition: format!("({}) AND ({})", self.condition, rhs.condition),
86            values: self.values.into_iter().chain(rhs.values).collect(),
87        }
88    }
89
90    fn or(self, rhs: Self) -> Self {
91        Self {
92            condition: format!("({}) OR ({})", self.condition, rhs.condition),
93            values: self.values.into_iter().chain(rhs.values).collect(),
94        }
95    }
96}
97
98impl PgSearchFilter {
99    fn into_clause(self) -> (String, Vec<serde_json::Value>) {
100        (self.condition, self.values)
101    }
102
103    #[allow(clippy::should_implement_trait)]
104    pub fn not(self) -> Self {
105        Self {
106            condition: format!("NOT ({})", self.condition),
107            values: self.values,
108        }
109    }
110
111    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
112        Self {
113            condition: format!("{key} >= ?"),
114            values: vec![value],
115        }
116    }
117
118    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
119        Self {
120            condition: format!("{key} <= ?"),
121            values: vec![value],
122        }
123    }
124
125    pub fn is_null(key: String) -> Self {
126        Self {
127            condition: format!("{key} is null"),
128            ..Default::default()
129        }
130    }
131
132    pub fn is_not_null(key: String) -> Self {
133        Self {
134            condition: format!("{key} is not null"),
135            ..Default::default()
136        }
137    }
138
139    pub fn between<T>(key: String, range: RangeInclusive<T>) -> Self
140    where
141        T: std::fmt::Display + Into<serde_json::Number> + Copy,
142    {
143        let lo = range.start();
144        let hi = range.end();
145
146        Self {
147            condition: format!("{key} between {lo} and {hi}"),
148            ..Default::default()
149        }
150    }
151
152    pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153        let placeholders = values.iter().map(|_| "?").collect::<Vec<&str>>().join(",");
154
155        Self {
156            condition: format!("{key} is in ({placeholders})"),
157            values,
158        }
159    }
160
161    // String matching ops
162
163    /// Tests whether the value at `key` matches the (case-sensitive) pattern
164    /// `pattern` should be a valid SQL string pattern, with '%' and '_' as wildcards
165    pub fn like(key: String, pattern: &'static str) -> Self {
166        Self {
167            condition: format!("{key} like {pattern}"),
168            ..Default::default()
169        }
170    }
171
172    /// Tests whether the value at `key` matches the SQL regex pattern
173    /// `pattern` should be a valid regex
174    pub fn similar_to(key: String, pattern: &'static str) -> Self {
175        Self {
176            condition: format!("{key} similar to {pattern}"),
177            ..Default::default()
178        }
179    }
180}
181
182fn bind_value<S>(
183    builder: QueryAs<'_, Postgres, S, PgArguments>,
184    value: Value,
185) -> QueryAs<'_, Postgres, S, PgArguments> {
186    match value {
187        Value::Null => unreachable!(),
188        Value::Bool(b) => builder.bind(b),
189        Value::Number(num) => {
190            if let Some(n) = num.as_f64() {
191                builder.bind(n)
192            } else if let Some(n) = num.as_i64() {
193                builder.bind(n)
194            } else {
195                unreachable!()
196            }
197        }
198        Value::String(s) => builder.bind(s),
199        Value::Array(xs) => {
200            if let Some(xs) = xs
201                .iter()
202                .map(|v| v.as_str().map(str::to_string))
203                .collect::<Option<Vec<_>>>()
204            {
205                builder.bind(xs)
206            } else if let Some(xs) = xs.iter().map(Value::as_f64).collect::<Option<Vec<_>>>() {
207                builder.bind(xs)
208            } else if let Some(xs) = xs.iter().map(Value::as_i64).collect::<Option<Vec<_>>>() {
209                builder.bind(xs)
210            } else if let Some(xs) = xs.iter().map(Value::as_bool).collect::<Option<Vec<_>>>() {
211                builder.bind(xs)
212            } else {
213                builder.bind(Value::Array(xs))
214            }
215        }
216        // Will always be JSONB
217        object => builder.bind(object),
218    }
219}
220
221#[derive(Debug, Deserialize, sqlx::FromRow)]
222pub struct SearchResult {
223    id: Uuid,
224    document: Value,
225    //embedded_text: String,
226    distance: f64,
227}
228
229#[derive(Debug, Deserialize, sqlx::FromRow)]
230pub struct SearchResultOnlyId {
231    id: Uuid,
232    distance: f64,
233}
234
235impl SearchResult {
236    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
237        let document: T =
238            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
239        Ok((self.distance, self.id.to_string(), document))
240    }
241}
242
243impl<Model> PostgresVectorStore<Model>
244where
245    Model: EmbeddingModel,
246{
247    pub fn new(
248        model: Model,
249        pg_pool: PgPool,
250        documents_table: Option<String>,
251        distance_function: PgVectorDistanceFunction,
252    ) -> Self {
253        Self {
254            model,
255            pg_pool,
256            documents_table: documents_table.unwrap_or(String::from("documents")),
257            distance_function,
258        }
259    }
260
261    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
262        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
263    }
264
265    fn search_query_full(
266        &self,
267        req: &VectorSearchRequest<PgSearchFilter>,
268    ) -> (String, Vec<serde_json::Value>) {
269        self.search_query(true, req)
270    }
271
272    fn search_query_only_ids(
273        &self,
274        req: &VectorSearchRequest<PgSearchFilter>,
275    ) -> (String, Vec<serde_json::Value>) {
276        self.search_query(false, req)
277    }
278
279    fn search_query(
280        &self,
281        with_document: bool,
282        req: &VectorSearchRequest<PgSearchFilter>,
283    ) -> (String, Vec<serde_json::Value>) {
284        let document = if with_document { ", document" } else { "" };
285
286        let thresh = req
287            .threshold()
288            .map(|t| PgSearchFilter::gt("distance", t.into()));
289        let filter = match (thresh, req.filter()) {
290            (Some(thresh), Some(filt)) => Some(thresh.and(filt.clone())),
291            (Some(thresh), _) => Some(thresh),
292            (_, Some(filt)) => Some(filt.clone()),
293            _ => None,
294        };
295        let (where_clause, params) = match filter {
296            Some(f) => {
297                let (expr, params) = f.into_clause();
298                (String::from("WHERE") + &expr, params)
299            }
300            None => (Default::default(), Default::default()),
301        };
302
303        let mut counter = 3;
304        let mut buf = String::with_capacity(where_clause.len() * 2);
305
306        for c in where_clause.chars() {
307            buf.push(c);
308
309            if c == '$' {
310                buf.push_str(counter.to_string().as_str());
311                counter += 1;
312            }
313        }
314
315        let where_clause = buf;
316
317        let query = format!(
318            "
319            SELECT id{}, distance FROM ( \
320              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
321              FROM {} \
322              {where_clause} \
323              ORDER BY id, distance \
324            ) as d \
325            ORDER BY distance \
326            LIMIT $2",
327            document, document, self.distance_function, self.documents_table
328        );
329
330        (query, params)
331    }
332}
333
334impl<Model> InsertDocuments for PostgresVectorStore<Model>
335where
336    Model: EmbeddingModel + Send + Sync,
337{
338    async fn insert_documents<Doc: Serialize + Embed + Send>(
339        &self,
340        documents: Vec<(Doc, OneOrMany<Embedding>)>,
341    ) -> Result<(), VectorStoreError> {
342        for (document, embeddings) in documents {
343            let id = Uuid::new_v4();
344            let json_document = serde_json::to_value(&document).unwrap();
345
346            for embedding in embeddings {
347                let embedding_text = embedding.document;
348                let embedding: Vec<f64> = embedding.vec;
349
350                sqlx::query(
351                    format!(
352                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
353                        self.documents_table
354                    )
355                    .as_str(),
356                )
357                .bind(id)
358                .bind(&json_document)
359                .bind(&embedding_text)
360                .bind(&embedding)
361                .execute(&self.pg_pool)
362                .await
363                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
364            }
365        }
366
367        Ok(())
368    }
369}
370
371impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
372where
373    Model: EmbeddingModel,
374{
375    type Filter = PgSearchFilter;
376
377    /// Get the top n documents based on the distance to the given query.
378    /// The result is a list of tuples of the form (score, id, document)
379    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
380        &self,
381        req: VectorSearchRequest<PgSearchFilter>,
382    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
383        if req.samples() > i64::MAX as u64 {
384            return Err(VectorStoreError::DatastoreError(
385                format!(
386                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
387                    i64::MAX
388                )
389                .into(),
390            ));
391        }
392
393        let embedded_query: pgvector::Vector = self
394            .model
395            .embed_text(req.query())
396            .await?
397            .vec
398            .iter()
399            .map(|&x| x as f32)
400            .collect::<Vec<f32>>()
401            .into();
402
403        let (search_query, params) = self.search_query_full(&req);
404        let builder = sqlx::query_as(search_query.as_str())
405            .bind(embedded_query)
406            .bind(req.samples() as i64);
407
408        let builder = params.iter().cloned().fold(builder, bind_value);
409
410        let rows = builder
411            .fetch_all(&self.pg_pool)
412            .await
413            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
414
415        let rows: Vec<(f64, String, T)> = rows
416            .into_iter()
417            .flat_map(SearchResult::into_result)
418            .collect();
419
420        Ok(rows)
421    }
422
423    /// Same as `top_n` but returns the document ids only.
424    async fn top_n_ids(
425        &self,
426        req: VectorSearchRequest<PgSearchFilter>,
427    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
428        if req.samples() > i64::MAX as u64 {
429            return Err(VectorStoreError::DatastoreError(
430                format!(
431                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
432                    i64::MAX
433                )
434                .into(),
435            ));
436        }
437        let embedded_query: pgvector::Vector = self
438            .model
439            .embed_text(req.query())
440            .await?
441            .vec
442            .iter()
443            .map(|&x| x as f32)
444            .collect::<Vec<f32>>()
445            .into();
446
447        let (search_query, params) = self.search_query_only_ids(&req);
448        let builder = sqlx::query_as(search_query.as_str())
449            .bind(embedded_query)
450            .bind(req.samples() as i64);
451
452        let builder = params.iter().cloned().fold(builder, bind_value);
453
454        let rows: Vec<SearchResultOnlyId> = builder
455            .fetch_all(&self.pg_pool)
456            .await
457            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
458
459        let rows: Vec<(f64, String)> = rows
460            .into_iter()
461            .map(|row| (row.distance, row.id.to_string()))
462            .collect();
463
464        Ok(rows)
465    }
466}