Skip to main content

rig_postgres/
lib.rs

1use std::{fmt::Display, ops::RangeInclusive};
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, VectorStoreError, VectorStoreIndex,
8        request::{SearchFilter, VectorSearchRequest},
9    },
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use serde_json::Value;
13use sqlx::{PgPool, Postgres, postgres::PgArguments, query::QueryAs};
14use uuid::Uuid;
15
16pub struct PostgresVectorStore<Model: EmbeddingModel> {
17    model: Model,
18    pg_pool: PgPool,
19    documents_table: String,
20    distance_function: PgVectorDistanceFunction,
21}
22
/// Distance functions supported by pgvector.
///
/// Each variant corresponds to one pgvector SQL operator; the operator is
/// noted on the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// `<->` - L2 distance
    L2,
    /// `<#>` - (negative) inner product
    InnerProduct,
    /// `<=>` - cosine distance
    Cosine,
    /// `<+>` - L1 distance (added in pgvector 0.7.0)
    L1,
    /// `<~>` - Hamming distance (binary vectors, added in pgvector 0.7.0)
    Hamming,
    /// `<%>` - Jaccard distance (binary vectors, added in pgvector 0.7.0)
    Jaccard,
}
39
40impl Display for PgVectorDistanceFunction {
41    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
42        match self {
43            PgVectorDistanceFunction::L2 => write!(f, "<->"),
44            PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
45            PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
46            PgVectorDistanceFunction::L1 => write!(f, "<+>"),
47            PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
48            PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
49        }
50    }
51}
52
53#[derive(Clone, Default, Serialize, Deserialize, Debug)]
54pub struct PgSearchFilter {
55    condition: String,
56    values: Vec<serde_json::Value>,
57}
58
59impl SearchFilter for PgSearchFilter {
60    type Value = serde_json::Value;
61
62    fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
63        Self {
64            condition: format!("{} = $", key.as_ref()),
65            values: vec![value],
66        }
67    }
68
69    fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
70        Self {
71            condition: format!("{} > $", key.as_ref()),
72            values: vec![value],
73        }
74    }
75
76    fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
77        Self {
78            condition: format!("{} < $", key.as_ref()),
79            values: vec![value],
80        }
81    }
82
83    fn and(self, rhs: Self) -> Self {
84        Self {
85            condition: format!("({}) AND ({})", self.condition, rhs.condition),
86            values: self.values.into_iter().chain(rhs.values).collect(),
87        }
88    }
89
90    fn or(self, rhs: Self) -> Self {
91        Self {
92            condition: format!("({}) OR ({})", self.condition, rhs.condition),
93            values: self.values.into_iter().chain(rhs.values).collect(),
94        }
95    }
96}
97
98impl PgSearchFilter {
99    fn into_clause(self) -> (String, Vec<serde_json::Value>) {
100        (self.condition, self.values)
101    }
102
103    #[allow(clippy::should_implement_trait)]
104    pub fn not(self) -> Self {
105        Self {
106            condition: format!("NOT ({})", self.condition),
107            values: self.values,
108        }
109    }
110
111    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
112        Self {
113            condition: format!("{key} >= ?"),
114            values: vec![value],
115        }
116    }
117
118    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
119        Self {
120            condition: format!("{key} <= ?"),
121            values: vec![value],
122        }
123    }
124
125    pub fn is_null(key: String) -> Self {
126        Self {
127            condition: format!("{key} is null"),
128            ..Default::default()
129        }
130    }
131
132    pub fn is_not_null(key: String) -> Self {
133        Self {
134            condition: format!("{key} is not null"),
135            ..Default::default()
136        }
137    }
138
139    pub fn between<T>(key: String, range: RangeInclusive<T>) -> Self
140    where
141        T: std::fmt::Display + Into<serde_json::Number> + Copy,
142    {
143        let lo = range.start();
144        let hi = range.end();
145
146        Self {
147            condition: format!("{key} between {lo} and {hi}"),
148            ..Default::default()
149        }
150    }
151
152    pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153        let placeholders = values.iter().map(|_| "?").collect::<Vec<&str>>().join(",");
154
155        Self {
156            condition: format!("{key} is in ({placeholders})"),
157            values,
158        }
159    }
160
161    // String matching ops
162
163    /// Tests whether the value at `key` matches the (case-sensitive) pattern
164    /// `pattern` should be a valid SQL string pattern, with '%' and '_' as wildcards
165    pub fn like(key: String, pattern: &'static str) -> Self {
166        Self {
167            condition: format!("{key} like {pattern}"),
168            ..Default::default()
169        }
170    }
171
172    /// Tests whether the value at `key` matches the SQL regex pattern
173    /// `pattern` should be a valid regex
174    pub fn similar_to(key: String, pattern: &'static str) -> Self {
175        Self {
176            condition: format!("{key} similar to {pattern}"),
177            ..Default::default()
178        }
179    }
180}
181
182fn bind_value<S>(
183    builder: QueryAs<'_, Postgres, S, PgArguments>,
184    value: Value,
185) -> QueryAs<'_, Postgres, S, PgArguments> {
186    match value {
187        Value::Null => builder.bind(Option::<String>::None),
188        Value::Bool(b) => builder.bind(b),
189        Value::Number(num) => {
190            if let Some(n) = num.as_f64() {
191                builder.bind(n)
192            } else if let Some(n) = num.as_i64() {
193                builder.bind(n)
194            } else if let Some(n) = num.as_u64() {
195                builder.bind(n as i64)
196            } else {
197                builder.bind(num.to_string())
198            }
199        }
200        Value::String(s) => builder.bind(s),
201        Value::Array(xs) => {
202            if let Some(xs) = xs
203                .iter()
204                .map(|v| v.as_str().map(str::to_string))
205                .collect::<Option<Vec<_>>>()
206            {
207                builder.bind(xs)
208            } else if let Some(xs) = xs.iter().map(Value::as_f64).collect::<Option<Vec<_>>>() {
209                builder.bind(xs)
210            } else if let Some(xs) = xs.iter().map(Value::as_i64).collect::<Option<Vec<_>>>() {
211                builder.bind(xs)
212            } else if let Some(xs) = xs.iter().map(Value::as_bool).collect::<Option<Vec<_>>>() {
213                builder.bind(xs)
214            } else {
215                builder.bind(Value::Array(xs))
216            }
217        }
218        // Will always be JSONB
219        object => builder.bind(object),
220    }
221}
222
223#[derive(Debug, Deserialize, sqlx::FromRow)]
224pub struct SearchResult {
225    id: Uuid,
226    document: Value,
227    //embedded_text: String,
228    distance: f64,
229}
230
231#[derive(Debug, Deserialize, sqlx::FromRow)]
232pub struct SearchResultOnlyId {
233    id: Uuid,
234    distance: f64,
235}
236
237impl SearchResult {
238    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
239        let document: T =
240            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
241        Ok((self.distance, self.id.to_string(), document))
242    }
243}
244
245impl<Model> PostgresVectorStore<Model>
246where
247    Model: EmbeddingModel,
248{
249    pub fn new(
250        model: Model,
251        pg_pool: PgPool,
252        documents_table: Option<String>,
253        distance_function: PgVectorDistanceFunction,
254    ) -> Self {
255        Self {
256            model,
257            pg_pool,
258            documents_table: documents_table.unwrap_or(String::from("documents")),
259            distance_function,
260        }
261    }
262
263    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
264        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
265    }
266
267    fn search_query_full(
268        &self,
269        req: &VectorSearchRequest<PgSearchFilter>,
270    ) -> (String, Vec<serde_json::Value>) {
271        self.search_query(true, req)
272    }
273
274    fn search_query_only_ids(
275        &self,
276        req: &VectorSearchRequest<PgSearchFilter>,
277    ) -> (String, Vec<serde_json::Value>) {
278        self.search_query(false, req)
279    }
280
281    fn search_query(
282        &self,
283        with_document: bool,
284        req: &VectorSearchRequest<PgSearchFilter>,
285    ) -> (String, Vec<serde_json::Value>) {
286        let document = if with_document { ", document" } else { "" };
287
288        let thresh = req
289            .threshold()
290            .map(|t| PgSearchFilter::gt("distance", t.into()));
291        let filter = match (thresh, req.filter()) {
292            (Some(thresh), Some(filt)) => Some(thresh.and(filt.clone())),
293            (Some(thresh), _) => Some(thresh),
294            (_, Some(filt)) => Some(filt.clone()),
295            _ => None,
296        };
297        let (where_clause, params) = match filter {
298            Some(f) => {
299                let (expr, params) = f.into_clause();
300                (String::from("WHERE") + &expr, params)
301            }
302            None => (Default::default(), Default::default()),
303        };
304
305        let mut counter = 3;
306        let mut buf = String::with_capacity(where_clause.len() * 2);
307
308        for c in where_clause.chars() {
309            buf.push(c);
310
311            if c == '$' {
312                buf.push_str(counter.to_string().as_str());
313                counter += 1;
314            }
315        }
316
317        let where_clause = buf;
318
319        let query = format!(
320            "
321            SELECT id{}, distance FROM ( \
322              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
323              FROM {} \
324              {where_clause} \
325              ORDER BY id, distance \
326            ) as d \
327            ORDER BY distance \
328            LIMIT $2",
329            document, document, self.distance_function, self.documents_table
330        );
331
332        (query, params)
333    }
334}
335
336impl<Model> InsertDocuments for PostgresVectorStore<Model>
337where
338    Model: EmbeddingModel + Send + Sync,
339{
340    async fn insert_documents<Doc: Serialize + Embed + Send>(
341        &self,
342        documents: Vec<(Doc, OneOrMany<Embedding>)>,
343    ) -> Result<(), VectorStoreError> {
344        for (document, embeddings) in documents {
345            let id = Uuid::new_v4();
346            let json_document = serde_json::to_value(&document)?;
347
348            for embedding in embeddings {
349                let embedding_text = embedding.document;
350                let embedding: Vec<f64> = embedding.vec;
351
352                sqlx::query(
353                    format!(
354                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
355                        self.documents_table
356                    )
357                    .as_str(),
358                )
359                .bind(id)
360                .bind(&json_document)
361                .bind(&embedding_text)
362                .bind(&embedding)
363                .execute(&self.pg_pool)
364                .await
365                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
366            }
367        }
368
369        Ok(())
370    }
371}
372
373impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
374where
375    Model: EmbeddingModel,
376{
377    type Filter = PgSearchFilter;
378
379    /// Get the top n documents based on the distance to the given query.
380    /// The result is a list of tuples of the form (score, id, document)
381    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
382        &self,
383        req: VectorSearchRequest<PgSearchFilter>,
384    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
385        if req.samples() > i64::MAX as u64 {
386            return Err(VectorStoreError::DatastoreError(
387                format!(
388                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
389                    i64::MAX
390                )
391                .into(),
392            ));
393        }
394
395        let embedded_query: pgvector::Vector = self
396            .model
397            .embed_text(req.query())
398            .await?
399            .vec
400            .iter()
401            .map(|&x| x as f32)
402            .collect::<Vec<f32>>()
403            .into();
404
405        let (search_query, params) = self.search_query_full(&req);
406        let builder = sqlx::query_as(search_query.as_str())
407            .bind(embedded_query)
408            .bind(req.samples() as i64);
409
410        let builder = params.iter().cloned().fold(builder, bind_value);
411
412        let rows = builder
413            .fetch_all(&self.pg_pool)
414            .await
415            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
416
417        let rows: Vec<(f64, String, T)> = rows
418            .into_iter()
419            .flat_map(SearchResult::into_result)
420            .collect();
421
422        Ok(rows)
423    }
424
425    /// Same as `top_n` but returns the document ids only.
426    async fn top_n_ids(
427        &self,
428        req: VectorSearchRequest<PgSearchFilter>,
429    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
430        if req.samples() > i64::MAX as u64 {
431            return Err(VectorStoreError::DatastoreError(
432                format!(
433                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
434                    i64::MAX
435                )
436                .into(),
437            ));
438        }
439        let embedded_query: pgvector::Vector = self
440            .model
441            .embed_text(req.query())
442            .await?
443            .vec
444            .iter()
445            .map(|&x| x as f32)
446            .collect::<Vec<f32>>()
447            .into();
448
449        let (search_query, params) = self.search_query_only_ids(&req);
450        let builder = sqlx::query_as(search_query.as_str())
451            .bind(embedded_query)
452            .bind(req.samples() as i64);
453
454        let builder = params.iter().cloned().fold(builder, bind_value);
455
456        let rows: Vec<SearchResultOnlyId> = builder
457            .fetch_all(&self.pg_pool)
458            .await
459            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
460
461        let rows: Vec<(f64, String)> = rows
462            .into_iter()
463            .map(|row| (row.distance, row.id.to_string()))
464            .collect();
465
466        Ok(rows)
467    }
468}