rig_postgres/
lib.rs

1use std::fmt::Display;
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
8    },
9};
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use serde_json::Value;
12use sqlx::PgPool;
13use uuid::Uuid;
14
15pub struct PostgresVectorStore<Model: EmbeddingModel> {
16    model: Model,
17    pg_pool: PgPool,
18    documents_table: String,
19    distance_function: PgVectorDistanceFunction,
20}
21
/// Distance operators supported by pgvector.
///
/// The [`Display`] implementation renders each variant as its SQL operator, which is
/// spliced directly into the search query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgVectorDistanceFunction {
    /// L2 distance, operator `<->`.
    L2,
    /// Negative inner product, operator `<#>`.
    InnerProduct,
    /// Cosine distance, operator `<=>`.
    Cosine,
    /// L1 distance, operator `<+>` (added in pgvector 0.7.0).
    L1,
    /// Hamming distance for binary vectors, operator `<~>` (added in pgvector 0.7.0).
    Hamming,
    /// Jaccard distance for binary vectors, operator `<%>` (added in pgvector 0.7.0).
    Jaccard,
}
38
39impl Display for PgVectorDistanceFunction {
40    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
41        match self {
42            PgVectorDistanceFunction::L2 => write!(f, "<->"),
43            PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
44            PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
45            PgVectorDistanceFunction::L1 => write!(f, "<+>"),
46            PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
47            PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
48        }
49    }
50}
51
52#[derive(Debug, Deserialize, sqlx::FromRow)]
53pub struct SearchResult {
54    id: Uuid,
55    document: Value,
56    //embedded_text: String,
57    distance: f64,
58}
59
60#[derive(Debug, Deserialize, sqlx::FromRow)]
61pub struct SearchResultOnlyId {
62    id: Uuid,
63    distance: f64,
64}
65
66impl SearchResult {
67    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
68        let document: T =
69            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
70        Ok((self.distance, self.id.to_string(), document))
71    }
72}
73
74impl<Model> PostgresVectorStore<Model>
75where
76    Model: EmbeddingModel,
77{
78    pub fn new(
79        model: Model,
80        pg_pool: PgPool,
81        documents_table: Option<String>,
82        distance_function: PgVectorDistanceFunction,
83    ) -> Self {
84        Self {
85            model,
86            pg_pool,
87            documents_table: documents_table.unwrap_or(String::from("documents")),
88            distance_function,
89        }
90    }
91
92    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
93        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
94    }
95
96    fn search_query_full(&self, threshold: Option<f64>) -> String {
97        self.search_query(true, threshold)
98    }
99    fn search_query_only_ids(&self, threshold: Option<f64>) -> String {
100        self.search_query(false, threshold)
101    }
102
103    fn search_query(&self, with_document: bool, threshold: Option<f64>) -> String {
104        let document = if with_document { ", document" } else { "" };
105        format!(
106            "
107            SELECT id{}, distance FROM ( \
108              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
109              FROM {} \
110              {where_clause}
111              ORDER BY id, distance \
112            ) as d \
113            ORDER BY distance \
114            LIMIT $2",
115            document,
116            document,
117            self.distance_function,
118            self.documents_table,
119            where_clause = if let Some(threshold) = threshold {
120                format!("where distance > {threshold}")
121            } else {
122                String::new()
123            }
124        )
125    }
126}
127
128impl<Model> InsertDocuments for PostgresVectorStore<Model>
129where
130    Model: EmbeddingModel + Send + Sync,
131{
132    async fn insert_documents<Doc: Serialize + Embed + Send>(
133        &self,
134        documents: Vec<(Doc, OneOrMany<Embedding>)>,
135    ) -> Result<(), VectorStoreError> {
136        for (document, embeddings) in documents {
137            let id = Uuid::new_v4();
138            let json_document = serde_json::to_value(&document).unwrap();
139
140            for embedding in embeddings {
141                let embedding_text = embedding.document;
142                let embedding: Vec<f64> = embedding.vec;
143
144                sqlx::query(
145                    format!(
146                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
147                        self.documents_table
148                    )
149                    .as_str(),
150                )
151                .bind(id)
152                .bind(&json_document)
153                .bind(&embedding_text)
154                .bind(&embedding)
155                .execute(&self.pg_pool)
156                .await
157                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
158            }
159        }
160
161        Ok(())
162    }
163}
164
165impl<Model> VectorStoreIndex for PostgresVectorStore<Model>
166where
167    Model: EmbeddingModel,
168{
169    /// Get the top n documents based on the distance to the given query.
170    /// The result is a list of tuples of the form (score, id, document)
171    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
172        &self,
173        req: VectorSearchRequest,
174    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
175        if req.samples() > i64::MAX as u64 {
176            return Err(VectorStoreError::DatastoreError(
177                format!(
178                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
179                    i64::MAX
180                )
181                .into(),
182            ));
183        }
184
185        let embedded_query: pgvector::Vector = self
186            .model
187            .embed_text(req.query())
188            .await?
189            .vec
190            .iter()
191            .map(|&x| x as f32)
192            .collect::<Vec<f32>>()
193            .into();
194
195        let rows: Vec<SearchResult> =
196            sqlx::query_as(self.search_query_full(req.threshold()).as_str())
197                .bind(embedded_query)
198                .bind(req.samples() as i64)
199                .fetch_all(&self.pg_pool)
200                .await
201                .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
202
203        let rows: Vec<(f64, String, T)> = rows
204            .into_iter()
205            .flat_map(SearchResult::into_result)
206            .collect();
207
208        Ok(rows)
209    }
210
211    /// Same as `top_n` but returns the document ids only.
212    async fn top_n_ids(
213        &self,
214        req: VectorSearchRequest,
215    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
216        if req.samples() > i64::MAX as u64 {
217            return Err(VectorStoreError::DatastoreError(
218                format!(
219                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
220                    i64::MAX
221                )
222                .into(),
223            ));
224        }
225        let embedded_query: pgvector::Vector = self
226            .model
227            .embed_text(req.query())
228            .await?
229            .vec
230            .iter()
231            .map(|&x| x as f32)
232            .collect::<Vec<f32>>()
233            .into();
234
235        let rows: Vec<SearchResultOnlyId> =
236            sqlx::query_as(self.search_query_only_ids(req.threshold()).as_str())
237                .bind(embedded_query)
238                .bind(req.samples() as i64)
239                .fetch_all(&self.pg_pool)
240                .await
241                .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
242
243        let rows: Vec<(f64, String)> = rows
244            .into_iter()
245            .map(|row| (row.distance, row.id.to_string()))
246            .collect();
247
248        Ok(rows)
249    }
250}