rig_postgres/
lib.rs

1use std::fmt::Display;
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{
7        InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
8    },
9};
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use serde_json::Value;
12use sqlx::PgPool;
13use uuid::Uuid;
14
15pub struct PostgresVectorStore<Model: EmbeddingModel> {
16    model: Model,
17    pg_pool: PgPool,
18    documents_table: String,
19    distance_function: PgVectorDistanceFunction,
20}
21
/// Distance operator used to rank rows in a similarity search.
///
/// Each variant maps to one pgvector operator (see [`PgVectorDistanceFunction::operator`]):
///
/// - `<->` — L2 distance
/// - `<#>` — (negative) inner product
/// - `<=>` — cosine distance
/// - `<+>` — L1 distance (added in pgvector 0.7.0)
/// - `<~>` — Hamming distance (binary vectors, added in pgvector 0.7.0)
/// - `<%>` — Jaccard distance (binary vectors, added in pgvector 0.7.0)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgVectorDistanceFunction {
    /// `<->`: Euclidean (L2) distance.
    L2,
    /// `<#>`: negative inner product.
    InnerProduct,
    /// `<=>`: cosine distance.
    Cosine,
    /// `<+>`: L1 distance; requires pgvector 0.7.0 or later.
    L1,
    /// `<~>`: Hamming distance on binary vectors; requires pgvector 0.7.0 or later.
    Hamming,
    /// `<%>`: Jaccard distance on binary vectors; requires pgvector 0.7.0 or later.
    Jaccard,
}

impl PgVectorDistanceFunction {
    /// Returns the SQL operator pgvector uses for this distance, e.g. `"<=>"` for cosine.
    pub const fn operator(self) -> &'static str {
        match self {
            Self::L2 => "<->",
            Self::InnerProduct => "<#>",
            Self::Cosine => "<=>",
            Self::L1 => "<+>",
            Self::Hamming => "<~>",
            Self::Jaccard => "<%>",
        }
    }
}

impl Display for PgVectorDistanceFunction {
    /// Writes the bare pgvector operator so the value can be interpolated into SQL.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.operator())
    }
}
51
52#[derive(Debug, Deserialize, sqlx::FromRow)]
53pub struct SearchResult {
54    id: Uuid,
55    document: Value,
56    //embedded_text: String,
57    distance: f64,
58}
59
60#[derive(Debug, Deserialize, sqlx::FromRow)]
61pub struct SearchResultOnlyId {
62    id: Uuid,
63    distance: f64,
64}
65
66impl SearchResult {
67    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
68        let document: T =
69            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
70        Ok((self.distance, self.id.to_string(), document))
71    }
72}
73
74impl<Model: EmbeddingModel> PostgresVectorStore<Model> {
75    pub fn new(
76        model: Model,
77        pg_pool: PgPool,
78        documents_table: Option<String>,
79        distance_function: PgVectorDistanceFunction,
80    ) -> Self {
81        Self {
82            model,
83            pg_pool,
84            documents_table: documents_table.unwrap_or(String::from("documents")),
85            distance_function,
86        }
87    }
88
89    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
90        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
91    }
92
93    fn search_query_full(&self) -> String {
94        self.search_query(true)
95    }
96    fn search_query_only_ids(&self) -> String {
97        self.search_query(false)
98    }
99
100    fn search_query(&self, with_document: bool) -> String {
101        let document = if with_document { ", document" } else { "" };
102        format!(
103            "
104            SELECT id{}, distance FROM ( \
105              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
106              FROM {} \
107              ORDER BY id, distance \
108            ) as d \
109            ORDER BY distance \
110            LIMIT $2",
111            document, document, self.distance_function, self.documents_table
112        )
113    }
114}
115
116impl<Model> InsertDocuments for PostgresVectorStore<Model>
117where
118    Model: EmbeddingModel + Send + Sync,
119{
120    async fn insert_documents<Doc: Serialize + Embed + Send>(
121        &self,
122        documents: Vec<(Doc, OneOrMany<Embedding>)>,
123    ) -> Result<(), VectorStoreError> {
124        for (document, embeddings) in documents {
125            let id = Uuid::new_v4();
126            let json_document = serde_json::to_value(&document).unwrap();
127
128            for embedding in embeddings {
129                let embedding_text = embedding.document;
130                let embedding: Vec<f64> = embedding.vec;
131
132                sqlx::query(
133                    format!(
134                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
135                        self.documents_table
136                    )
137                    .as_str(),
138                )
139                .bind(id)
140                .bind(&json_document)
141                .bind(&embedding_text)
142                .bind(&embedding)
143                .execute(&self.pg_pool)
144                .await
145                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
146            }
147        }
148
149        Ok(())
150    }
151}
152
153impl<Model: EmbeddingModel> VectorStoreIndex for PostgresVectorStore<Model> {
154    /// Get the top n documents based on the distance to the given query.
155    /// The result is a list of tuples of the form (score, id, document)
156    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
157        &self,
158        req: VectorSearchRequest,
159    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
160        if req.samples() > i64::MAX as u64 {
161            return Err(VectorStoreError::DatastoreError(
162                format!(
163                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
164                    i64::MAX
165                )
166                .into(),
167            ));
168        }
169
170        let embedded_query: pgvector::Vector = self
171            .model
172            .embed_text(req.query())
173            .await?
174            .vec
175            .iter()
176            .map(|&x| x as f32)
177            .collect::<Vec<f32>>()
178            .into();
179
180        let rows: Vec<SearchResult> = sqlx::query_as(self.search_query_full().as_str())
181            .bind(embedded_query)
182            .bind(req.samples() as i64)
183            .fetch_all(&self.pg_pool)
184            .await
185            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
186
187        let rows: Vec<(f64, String, T)> = rows
188            .into_iter()
189            .flat_map(SearchResult::into_result)
190            .collect();
191
192        Ok(rows)
193    }
194
195    /// Same as `top_n` but returns the document ids only.
196    async fn top_n_ids(
197        &self,
198        req: VectorSearchRequest,
199    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
200        if req.samples() > i64::MAX as u64 {
201            return Err(VectorStoreError::DatastoreError(
202                format!(
203                    "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
204                    i64::MAX
205                )
206                .into(),
207            ));
208        }
209        let embedded_query: pgvector::Vector = self
210            .model
211            .embed_text(req.query())
212            .await?
213            .vec
214            .iter()
215            .map(|&x| x as f32)
216            .collect::<Vec<f32>>()
217            .into();
218
219        let rows: Vec<SearchResultOnlyId> = sqlx::query_as(self.search_query_only_ids().as_str())
220            .bind(embedded_query)
221            .bind(req.samples() as i64)
222            .fetch_all(&self.pg_pool)
223            .await
224            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
225
226        let rows: Vec<(f64, String)> = rows
227            .into_iter()
228            .map(|row| (row.distance, row.id.to_string()))
229            .collect();
230
231        Ok(rows)
232    }
233}