rig_postgres/
lib.rs

1use std::fmt::Display;
2
3use rig::{
4    Embed, OneOrMany,
5    embeddings::{Embedding, EmbeddingModel},
6    vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex},
7};
8use serde::{Deserialize, Serialize, de::DeserializeOwned};
9use serde_json::Value;
10use sqlx::PgPool;
11use uuid::Uuid;
12
13pub struct PostgresVectorStore<Model: EmbeddingModel> {
14    model: Model,
15    pg_pool: PgPool,
16    documents_table: String,
17    distance_function: PgVectorDistanceFunction,
18}
19
/// Distance operators supported by pgvector, used to rank rows against a query embedding.
///
/// The [`Display`] implementation renders each variant as the SQL operator that is
/// spliced into the search query.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PgVectorDistanceFunction {
    /// `<->` — L2 (Euclidean) distance.
    L2,
    /// `<#>` — negative inner product.
    InnerProduct,
    /// `<=>` — cosine distance.
    Cosine,
    /// `<+>` — L1 distance (added in pgvector 0.7.0).
    L1,
    /// `<~>` — Hamming distance for binary vectors (added in pgvector 0.7.0).
    ///
    /// NOTE(review): search queries bind the query as a `pgvector::Vector`; confirm this
    /// operator is valid for your embedding column type before selecting it.
    Hamming,
    /// `<%>` — Jaccard distance for binary vectors (added in pgvector 0.7.0).
    ///
    /// NOTE(review): search queries bind the query as a `pgvector::Vector`; confirm this
    /// operator is valid for your embedding column type before selecting it.
    Jaccard,
}

impl Display for PgVectorDistanceFunction {
    /// Writes the pgvector SQL operator corresponding to this distance function.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let operator = match self {
            Self::L2 => "<->",
            Self::InnerProduct => "<#>",
            Self::Cosine => "<=>",
            Self::L1 => "<+>",
            Self::Hamming => "<~>",
            Self::Jaccard => "<%>",
        };
        f.write_str(operator)
    }
}
49
50#[derive(Debug, Deserialize, sqlx::FromRow)]
51pub struct SearchResult {
52    id: Uuid,
53    document: Value,
54    //embedded_text: String,
55    distance: f64,
56}
57
58#[derive(Debug, Deserialize, sqlx::FromRow)]
59pub struct SearchResultOnlyId {
60    id: Uuid,
61    distance: f64,
62}
63
64impl SearchResult {
65    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
66        let document: T =
67            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
68        Ok((self.distance, self.id.to_string(), document))
69    }
70}
71
72impl<Model: EmbeddingModel> PostgresVectorStore<Model> {
73    pub fn new(
74        model: Model,
75        pg_pool: PgPool,
76        documents_table: Option<String>,
77        distance_function: PgVectorDistanceFunction,
78    ) -> Self {
79        Self {
80            model,
81            pg_pool,
82            documents_table: documents_table.unwrap_or(String::from("documents")),
83            distance_function,
84        }
85    }
86
87    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
88        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
89    }
90
91    fn search_query_full(&self) -> String {
92        self.search_query(true)
93    }
94    fn search_query_only_ids(&self) -> String {
95        self.search_query(false)
96    }
97
98    fn search_query(&self, with_document: bool) -> String {
99        let document = if with_document { ", document" } else { "" };
100        format!(
101            "
102            SELECT id{}, distance FROM ( \
103              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
104              FROM {} \
105              ORDER BY id, distance \
106            ) as d \
107            ORDER BY distance \
108            LIMIT $2",
109            document, document, self.distance_function, self.documents_table
110        )
111    }
112}
113
114impl<Model> InsertDocuments for PostgresVectorStore<Model>
115where
116    Model: EmbeddingModel + Send + Sync,
117{
118    async fn insert_documents<Doc: Serialize + Embed + Send>(
119        &self,
120        documents: Vec<(Doc, OneOrMany<Embedding>)>,
121    ) -> Result<(), VectorStoreError> {
122        for (document, embeddings) in documents {
123            let id = Uuid::new_v4();
124            let json_document = serde_json::to_value(&document).unwrap();
125
126            for embedding in embeddings {
127                let embedding_text = embedding.document;
128                let embedding: Vec<f64> = embedding.vec;
129
130                sqlx::query(
131                    format!(
132                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
133                        self.documents_table
134                    )
135                    .as_str(),
136                )
137                .bind(id)
138                .bind(&json_document)
139                .bind(&embedding_text)
140                .bind(&embedding)
141                .execute(&self.pg_pool)
142                .await
143                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
144            }
145        }
146
147        Ok(())
148    }
149}
150
151impl<Model: EmbeddingModel> VectorStoreIndex for PostgresVectorStore<Model> {
152    /// Get the top n documents based on the distance to the given query.
153    /// The result is a list of tuples of the form (score, id, document)
154    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
155        &self,
156        query: &str,
157        n: usize,
158    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
159        let embedded_query: pgvector::Vector = self
160            .model
161            .embed_text(query)
162            .await?
163            .vec
164            .iter()
165            .map(|&x| x as f32)
166            .collect::<Vec<f32>>()
167            .into();
168
169        let rows: Vec<SearchResult> = sqlx::query_as(self.search_query_full().as_str())
170            .bind(embedded_query)
171            .bind(n as i64)
172            .fetch_all(&self.pg_pool)
173            .await
174            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
175
176        let rows: Vec<(f64, String, T)> = rows
177            .into_iter()
178            .flat_map(SearchResult::into_result)
179            .collect();
180
181        Ok(rows)
182    }
183
184    /// Same as `top_n` but returns the document ids only.
185    async fn top_n_ids(
186        &self,
187        query: &str,
188        n: usize,
189    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
190        let embedded_query: pgvector::Vector = self
191            .model
192            .embed_text(query)
193            .await?
194            .vec
195            .iter()
196            .map(|&x| x as f32)
197            .collect::<Vec<f32>>()
198            .into();
199
200        let rows: Vec<SearchResultOnlyId> = sqlx::query_as(self.search_query_only_ids().as_str())
201            .bind(embedded_query)
202            .bind(n as i64)
203            .fetch_all(&self.pg_pool)
204            .await
205            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
206
207        let rows: Vec<(f64, String)> = rows
208            .into_iter()
209            .map(|row| (row.distance, row.id.to_string()))
210            .collect();
211
212        Ok(rows)
213    }
214}