rig_postgres/
lib.rs

1use std::fmt::Display;
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
8    },
9};
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use serde_json::Value;
12use sqlx::PgPool;
13use uuid::Uuid;
14
15pub struct PostgresVectorStore<Model: EmbeddingModel> {
16    model: Model,
17    pg_pool: PgPool,
18    documents_table: String,
19    distance_function: PgVectorDistanceFunction,
20}
21
/// Distance function used to rank documents, expressed as a pgvector SQL operator.
///
/// The operator text (see [`PgVectorDistanceFunction::operator`]) is what this type's
/// `Display` implementation writes, so a value can be interpolated directly into SQL.
///
/// `L1`, `Hamming` and `Jaccard` were added in pgvector 0.7.0. `Hamming` and `Jaccard` are
/// defined for binary vectors. NOTE(review): the search queries in this crate bind a float
/// `pgvector::Vector`, so confirm those two operators are usable with your table schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgVectorDistanceFunction {
    /// `<->` — L2 (Euclidean) distance.
    L2,
    /// `<#>` — negative inner product (negated so that ascending order ranks best matches first).
    InnerProduct,
    /// `<=>` — cosine distance.
    Cosine,
    /// `<+>` — L1 distance (pgvector 0.7.0+).
    L1,
    /// `<~>` — Hamming distance for binary vectors (pgvector 0.7.0+).
    Hamming,
    /// `<%>` — Jaccard distance for binary vectors (pgvector 0.7.0+).
    Jaccard,
}

impl PgVectorDistanceFunction {
    /// Returns the pgvector SQL operator corresponding to this distance function.
    pub const fn operator(&self) -> &'static str {
        match self {
            Self::L2 => "<->",
            Self::InnerProduct => "<#>",
            Self::Cosine => "<=>",
            Self::L1 => "<+>",
            Self::Hamming => "<~>",
            Self::Jaccard => "<%>",
        }
    }
}

impl Display for PgVectorDistanceFunction {
    /// Writes the pgvector operator (e.g. `<=>`) so the value can be spliced into a query.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.operator())
    }
}
51
52#[derive(Debug, Deserialize, sqlx::FromRow)]
53pub struct SearchResult {
54    id: Uuid,
55    document: Value,
56    //embedded_text: String,
57    distance: f64,
58}
59
60#[derive(Debug, Deserialize, sqlx::FromRow)]
61pub struct SearchResultOnlyId {
62    id: Uuid,
63    distance: f64,
64}
65
66impl SearchResult {
67    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
68        let document: T =
69            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
70        Ok((self.distance, self.id.to_string(), document))
71    }
72}
73
74impl<Model: EmbeddingModel> PostgresVectorStore<Model> {
75    pub fn new(
76        model: Model,
77        pg_pool: PgPool,
78        documents_table: Option<String>,
79        distance_function: PgVectorDistanceFunction,
80    ) -> Self {
81        Self {
82            model,
83            pg_pool,
84            documents_table: documents_table.unwrap_or(String::from("documents")),
85            distance_function,
86        }
87    }
88
89    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
90        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
91    }
92
93    fn search_query_full(&self, threshold: Option<f64>) -> String {
94        self.search_query(true, threshold)
95    }
96    fn search_query_only_ids(&self, threshold: Option<f64>) -> String {
97        self.search_query(false, threshold)
98    }
99
100    fn search_query(&self, with_document: bool, threshold: Option<f64>) -> String {
101        let document = if with_document { ", document" } else { "" };
102        format!(
103            "
104            SELECT id{}, distance FROM ( \
105              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
106              FROM {} \
107              {where_clause}
108              ORDER BY id, distance \
109            ) as d \
110            ORDER BY distance \
111            LIMIT $2",
112            document,
113            document,
114            self.distance_function,
115            self.documents_table,
116            where_clause = if let Some(threshold) = threshold {
117                format!("where distance > {threshold}")
118            } else {
119                String::new()
120            }
121        )
122    }
123}
124
125impl<Model> InsertDocuments for PostgresVectorStore<Model>
126where
127    Model: EmbeddingModel + Send + Sync,
128{
129    async fn insert_documents<Doc: Serialize + Embed + Send>(
130        &self,
131        documents: Vec<(Doc, OneOrMany<Embedding>)>,
132    ) -> Result<(), VectorStoreError> {
133        for (document, embeddings) in documents {
134            let id = Uuid::new_v4();
135            let json_document = serde_json::to_value(&document).unwrap();
136
137            for embedding in embeddings {
138                let embedding_text = embedding.document;
139                let embedding: Vec<f64> = embedding.vec;
140
141                sqlx::query(
142                    format!(
143                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
144                        self.documents_table
145                    )
146                    .as_str(),
147                )
148                .bind(id)
149                .bind(&json_document)
150                .bind(&embedding_text)
151                .bind(&embedding)
152                .execute(&self.pg_pool)
153                .await
154                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
155            }
156        }
157
158        Ok(())
159    }
160}
161
162impl<Model: EmbeddingModel> VectorStoreIndex for PostgresVectorStore<Model> {
163    /// Get the top n documents based on the distance to the given query.
164    /// The result is a list of tuples of the form (score, id, document)
165    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
166        &self,
167        req: VectorSearchRequest,
168    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
169        if req.samples() > i64::MAX as u64 {
170            return Err(VectorStoreError::DatastoreError(
171                format!(
172                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
173                    i64::MAX
174                )
175                .into(),
176            ));
177        }
178
179        let embedded_query: pgvector::Vector = self
180            .model
181            .embed_text(req.query())
182            .await?
183            .vec
184            .iter()
185            .map(|&x| x as f32)
186            .collect::<Vec<f32>>()
187            .into();
188
189        let rows: Vec<SearchResult> =
190            sqlx::query_as(self.search_query_full(req.threshold()).as_str())
191                .bind(embedded_query)
192                .bind(req.samples() as i64)
193                .fetch_all(&self.pg_pool)
194                .await
195                .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
196
197        let rows: Vec<(f64, String, T)> = rows
198            .into_iter()
199            .flat_map(SearchResult::into_result)
200            .collect();
201
202        Ok(rows)
203    }
204
205    /// Same as `top_n` but returns the document ids only.
206    async fn top_n_ids(
207        &self,
208        req: VectorSearchRequest,
209    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
210        if req.samples() > i64::MAX as u64 {
211            return Err(VectorStoreError::DatastoreError(
212                format!(
213                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
214                    i64::MAX
215                )
216                .into(),
217            ));
218        }
219        let embedded_query: pgvector::Vector = self
220            .model
221            .embed_text(req.query())
222            .await?
223            .vec
224            .iter()
225            .map(|&x| x as f32)
226            .collect::<Vec<f32>>()
227            .into();
228
229        let rows: Vec<SearchResultOnlyId> =
230            sqlx::query_as(self.search_query_only_ids(req.threshold()).as_str())
231                .bind(embedded_query)
232                .bind(req.samples() as i64)
233                .fetch_all(&self.pg_pool)
234                .await
235                .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
236
237        let rows: Vec<(f64, String)> = rows
238            .into_iter()
239            .map(|row| (row.distance, row.id.to_string()))
240            .collect();
241
242        Ok(rows)
243    }
244}