rig_postgres/
lib.rs

1use std::{fmt::Display, ops::RangeInclusive};
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, VectorStoreError, VectorStoreIndex,
8        request::{SearchFilter, VectorSearchRequest},
9    },
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use serde_json::Value;
13use sqlx::PgPool;
14use uuid::Uuid;
15
/// A vector store that keeps documents and their embeddings in a Postgres table
/// and searches them with a pgvector distance operator.
///
/// Searches embed the query text with `model` and rank rows by applying
/// `distance_function` between the table's `embedding` column and that query
/// embedding.
pub struct PostgresVectorStore<Model: EmbeddingModel> {
    /// Embedding model used to embed the query text at search time.
    model: Model,
    /// Connection pool every statement is executed against.
    pg_pool: PgPool,
    /// Name of the table holding the documents; it is spliced verbatim into the
    /// SQL text of the insert and search statements.
    documents_table: String,
    /// pgvector operator used to compute the distance to the query embedding.
    distance_function: PgVectorDistanceFunction,
}
22
/// Distance operators supported by pgvector.
///
/// The `Display` implementation renders each variant as the SQL operator listed
/// on it, so the value can be spliced directly into a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// `<->` - L2 distance
    L2,
    /// `<#>` - (negative) inner product
    InnerProduct,
    /// `<=>` - cosine distance
    Cosine,
    /// `<+>` - L1 distance (added in pgvector 0.7.0)
    L1,
    /// `<~>` - Hamming distance (binary vectors, added in pgvector 0.7.0)
    Hamming,
    /// `<%>` - Jaccard distance (binary vectors, added in pgvector 0.7.0)
    Jaccard,
}
39
40impl Display for PgVectorDistanceFunction {
41    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
42        match self {
43            PgVectorDistanceFunction::L2 => write!(f, "<->"),
44            PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
45            PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
46            PgVectorDistanceFunction::L1 => write!(f, "<+>"),
47            PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
48            PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
49        }
50    }
51}
52
/// A SQL filter expression together with the values to bind to its placeholders.
#[derive(Clone, Default)]
pub struct PgSearchFilter {
    /// SQL boolean expression. Every `$` character in it is given a parameter
    /// number (starting at 3) when the search query is assembled.
    condition: String,
    /// Values handed back alongside `condition`; they are bound to the search
    /// query in this order, after the query embedding and the row limit.
    values: Vec<serde_json::Value>,
}
58
59impl SearchFilter for PgSearchFilter {
60    type Value = serde_json::Value;
61
62    fn eq(key: String, value: Self::Value) -> Self {
63        Self {
64            condition: format!("{key} = $"),
65            values: vec![value],
66        }
67    }
68
69    fn gt(key: String, value: Self::Value) -> Self {
70        Self {
71            condition: format!("{key} > $"),
72            values: vec![value],
73        }
74    }
75
76    fn lt(key: String, value: Self::Value) -> Self {
77        Self {
78            condition: format!("{key} < $"),
79            values: vec![value],
80        }
81    }
82
83    fn and(self, rhs: Self) -> Self {
84        Self {
85            condition: format!("({}) AND ({})", self.condition, rhs.condition),
86            values: self.values.into_iter().chain(rhs.values).collect(),
87        }
88    }
89
90    fn or(self, rhs: Self) -> Self {
91        Self {
92            condition: format!("({}) OR ({})", self.condition, rhs.condition),
93            values: self.values.into_iter().chain(rhs.values).collect(),
94        }
95    }
96}
97
98impl PgSearchFilter {
99    fn into_clause(self) -> (String, Vec<serde_json::Value>) {
100        (self.condition, self.values)
101    }
102
103    #[allow(clippy::should_implement_trait)]
104    pub fn not(self) -> Self {
105        Self {
106            condition: format!("NOT ({})", self.condition),
107            values: self.values,
108        }
109    }
110
111    pub fn gte(key: String, value: <Self as SearchFilter>::Value) -> Self {
112        Self {
113            condition: format!("{key} >= ?"),
114            values: vec![value],
115        }
116    }
117
118    pub fn lte(key: String, value: <Self as SearchFilter>::Value) -> Self {
119        Self {
120            condition: format!("{key} <= ?"),
121            values: vec![value],
122        }
123    }
124
125    pub fn is_null(key: String) -> Self {
126        Self {
127            condition: format!("{key} is null"),
128            ..Default::default()
129        }
130    }
131
132    pub fn is_not_null(key: String) -> Self {
133        Self {
134            condition: format!("{key} is not null"),
135            ..Default::default()
136        }
137    }
138
139    pub fn between<T>(key: String, range: RangeInclusive<T>) -> Self
140    where
141        T: std::fmt::Display + Into<serde_json::Number> + Copy,
142    {
143        let lo = range.start();
144        let hi = range.end();
145
146        Self {
147            condition: format!("{key} between {lo} and {hi}"),
148            ..Default::default()
149        }
150    }
151
152    pub fn member(key: String, values: Vec<<Self as SearchFilter>::Value>) -> Self {
153        let placeholders = values.iter().map(|_| "?").collect::<Vec<&str>>().join(",");
154
155        Self {
156            condition: format!("{key} is in ({placeholders})"),
157            values,
158        }
159    }
160
161    // String matching ops
162
163    /// Tests whether the value at `key` matches the (case-sensitive) pattern
164    /// `pattern` should be a valid SQL string pattern, with '%' and '_' as wildcards
165    pub fn like(key: String, pattern: &'static str) -> Self {
166        Self {
167            condition: format!("{key} like {pattern}"),
168            ..Default::default()
169        }
170    }
171
172    /// Tests whether the value at `key` matches the SQL regex pattern
173    /// `pattern` should be a valid regex
174    pub fn similar_to(key: String, pattern: &'static str) -> Self {
175        Self {
176            condition: format!("{key} similar to {pattern}"),
177            ..Default::default()
178        }
179    }
180}
181
/// A row produced by the full search query: document id, stored JSON document
/// and distance to the query embedding.
#[derive(Debug, Deserialize, sqlx::FromRow)]
pub struct SearchResult {
    /// Id of the matched document.
    id: Uuid,
    /// The stored document as raw JSON; `into_result` deserializes it into the
    /// caller's type.
    document: Value,
    // `embedded_text` is not selected by the search query, so it is not mapped here.
    /// Value of the `distance` column computed by the search query.
    distance: f64,
}
189
/// A row produced by the id-only search query: document id and distance to the
/// query embedding, without the document itself.
#[derive(Debug, Deserialize, sqlx::FromRow)]
pub struct SearchResultOnlyId {
    /// Id of the matched document.
    id: Uuid,
    /// Value of the `distance` column computed by the search query.
    distance: f64,
}
195
196impl SearchResult {
197    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
198        let document: T =
199            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
200        Ok((self.distance, self.id.to_string(), document))
201    }
202}
203
204impl<Model> PostgresVectorStore<Model>
205where
206    Model: EmbeddingModel,
207{
208    pub fn new(
209        model: Model,
210        pg_pool: PgPool,
211        documents_table: Option<String>,
212        distance_function: PgVectorDistanceFunction,
213    ) -> Self {
214        Self {
215            model,
216            pg_pool,
217            documents_table: documents_table.unwrap_or(String::from("documents")),
218            distance_function,
219        }
220    }
221
222    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
223        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
224    }
225
226    fn search_query_full(
227        &self,
228        req: &VectorSearchRequest<PgSearchFilter>,
229    ) -> (String, Vec<serde_json::Value>) {
230        self.search_query(true, req)
231    }
232
233    fn search_query_only_ids(
234        &self,
235        req: &VectorSearchRequest<PgSearchFilter>,
236    ) -> (String, Vec<serde_json::Value>) {
237        self.search_query(false, req)
238    }
239
240    fn search_query(
241        &self,
242        with_document: bool,
243        req: &VectorSearchRequest<PgSearchFilter>,
244    ) -> (String, Vec<serde_json::Value>) {
245        let document = if with_document { ", document" } else { "" };
246
247        let thresh = req
248            .threshold()
249            .map(|t| PgSearchFilter::gt("distance".into(), t.into()));
250        let filter = match (thresh, req.filter()) {
251            (Some(thresh), Some(filt)) => Some(thresh.and(filt.clone())),
252            (Some(thresh), _) => Some(thresh),
253            (_, Some(filt)) => Some(filt.clone()),
254            _ => None,
255        };
256        let (where_clause, params) = match filter {
257            Some(f) => {
258                let (expr, params) = f.into_clause();
259                (String::from("WHERE") + &expr, params)
260            }
261            None => (Default::default(), Default::default()),
262        };
263
264        let mut counter = 3;
265        let mut buf = String::with_capacity(where_clause.len() * 2);
266
267        for c in where_clause.chars() {
268            buf.push(c);
269
270            if c == '$' {
271                buf.push_str(counter.to_string().as_str());
272                counter += 1;
273            }
274        }
275
276        let where_clause = buf;
277
278        let query = format!(
279            "
280            SELECT id{}, distance FROM ( \
281              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
282              FROM {} \
283              {where_clause} \
284              ORDER BY id, distance \
285            ) as d \
286            ORDER BY distance \
287            LIMIT $2",
288            document, document, self.distance_function, self.documents_table
289        );
290
291        (query, params)
292    }
293}
294
295impl<Model> InsertDocuments for PostgresVectorStore<Model>
296where
297    Model: EmbeddingModel + Send + Sync,
298{
299    async fn insert_documents<Doc: Serialize + Embed + Send>(
300        &self,
301        documents: Vec<(Doc, OneOrMany<Embedding>)>,
302    ) -> Result<(), VectorStoreError> {
303        for (document, embeddings) in documents {
304            let id = Uuid::new_v4();
305            let json_document = serde_json::to_value(&document).unwrap();
306
307            for embedding in embeddings {
308                let embedding_text = embedding.document;
309                let embedding: Vec<f64> = embedding.vec;
310
311                sqlx::query(
312                    format!(
313                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
314                        self.documents_table
315                    )
316                    .as_str(),
317                )
318                .bind(id)
319                .bind(&json_document)
320                .bind(&embedding_text)
321                .bind(&embedding)
322                .execute(&self.pg_pool)
323                .await
324                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
325            }
326        }
327
328        Ok(())
329    }
330}
331
332impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
333where
334    Model: EmbeddingModel,
335{
336    type Filter = PgSearchFilter;
337
338    /// Get the top n documents based on the distance to the given query.
339    /// The result is a list of tuples of the form (score, id, document)
340    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
341        &self,
342        req: VectorSearchRequest<PgSearchFilter>,
343    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
344        if req.samples() > i64::MAX as u64 {
345            return Err(VectorStoreError::DatastoreError(
346                format!(
347                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
348                    i64::MAX
349                )
350                .into(),
351            ));
352        }
353
354        let embedded_query: pgvector::Vector = self
355            .model
356            .embed_text(req.query())
357            .await?
358            .vec
359            .iter()
360            .map(|&x| x as f32)
361            .collect::<Vec<f32>>()
362            .into();
363
364        let (search_query, params) = self.search_query_full(&req);
365        let builder = sqlx::query_as(search_query.as_str())
366            .bind(embedded_query)
367            .bind(req.samples() as i64);
368
369        let builder = params
370            .iter()
371            .fold(builder, |builder, param| builder.bind(param));
372
373        let rows = builder
374            .fetch_all(&self.pg_pool)
375            .await
376            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
377
378        let rows: Vec<(f64, String, T)> = rows
379            .into_iter()
380            .flat_map(SearchResult::into_result)
381            .collect();
382
383        Ok(rows)
384    }
385
386    /// Same as `top_n` but returns the document ids only.
387    async fn top_n_ids(
388        &self,
389        req: VectorSearchRequest<PgSearchFilter>,
390    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
391        if req.samples() > i64::MAX as u64 {
392            return Err(VectorStoreError::DatastoreError(
393                format!(
394                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
395                    i64::MAX
396                )
397                .into(),
398            ));
399        }
400        let embedded_query: pgvector::Vector = self
401            .model
402            .embed_text(req.query())
403            .await?
404            .vec
405            .iter()
406            .map(|&x| x as f32)
407            .collect::<Vec<f32>>()
408            .into();
409
410        let (search_query, params) = self.search_query_only_ids(&req);
411        let builder = sqlx::query_as(search_query.as_str())
412            .bind(embedded_query)
413            .bind(req.samples() as i64);
414
415        let builder = params
416            .iter()
417            .fold(builder, |builder, param| builder.bind(param));
418
419        let rows: Vec<SearchResultOnlyId> = builder
420            .fetch_all(&self.pg_pool)
421            .await
422            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
423
424        let rows: Vec<(f64, String)> = rows
425            .into_iter()
426            .map(|row| (row.distance, row.id.to_string()))
427            .collect();
428
429        Ok(rows)
430    }
431}