rig_postgres/
lib.rs

1use std::fmt::Display;
2
3use rig::{
4    embeddings::{Embedding, EmbeddingModel},
5    vector_store::{VectorStoreError, VectorStoreIndex},
6    Embed, OneOrMany,
7};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10use sqlx::PgPool;
11use uuid::Uuid;
12
13pub struct PostgresVectorStore<Model: EmbeddingModel> {
14    model: Model,
15    pg_pool: PgPool,
16    documents_table: String,
17    distance_function: PgVectorDistanceFunction,
18}
19
/// Distance functions supported by pgvector.
///
/// The `Display` implementation of this type renders the SQL operator that is
/// listed next to each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// `<->` - L2 distance
    L2,
    /// `<#>` - (negative) inner product
    InnerProduct,
    /// `<=>` - cosine distance
    Cosine,
    /// `<+>` - L1 distance (added in 0.7.0)
    L1,
    /// `<~>` - Hamming distance (binary vectors, added in 0.7.0)
    Hamming,
    /// `<%>` - Jaccard distance (binary vectors, added in 0.7.0)
    Jaccard,
}
36
37impl Display for PgVectorDistanceFunction {
38    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
39        match self {
40            PgVectorDistanceFunction::L2 => write!(f, "<->"),
41            PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
42            PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
43            PgVectorDistanceFunction::L1 => write!(f, "<+>"),
44            PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
45            PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
46        }
47    }
48}
49
50#[derive(Debug, Deserialize, sqlx::FromRow)]
51pub struct SearchResult {
52    id: Uuid,
53    document: Value,
54    //embedded_text: String,
55    distance: f64,
56}
57
58#[derive(Debug, Deserialize, sqlx::FromRow)]
59pub struct SearchResultOnlyId {
60    id: Uuid,
61    distance: f64,
62}
63
64impl SearchResult {
65    pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
66        let document: T =
67            serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
68        Ok((self.distance, self.id.to_string(), document))
69    }
70}
71
72impl<Model: EmbeddingModel> PostgresVectorStore<Model> {
73    pub fn new(
74        model: Model,
75        pg_pool: PgPool,
76        documents_table: Option<String>,
77        distance_function: PgVectorDistanceFunction,
78    ) -> Self {
79        Self {
80            model,
81            pg_pool,
82            documents_table: documents_table.unwrap_or(String::from("documents")),
83            distance_function,
84        }
85    }
86
87    pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
88        Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
89    }
90
91    fn search_query_full(&self) -> String {
92        self.search_query(true)
93    }
94    fn search_query_only_ids(&self) -> String {
95        self.search_query(false)
96    }
97
98    fn search_query(&self, with_document: bool) -> String {
99        let document = if with_document { ", document" } else { "" };
100        format!(
101            "
102            SELECT id{}, distance FROM ( \
103              SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
104              FROM {} \
105              ORDER BY id, distance \
106            ) as d \
107            ORDER BY distance \
108            LIMIT $2",
109            document, document, self.distance_function, self.documents_table
110        )
111    }
112
113    pub async fn insert_documents<Doc: Serialize + Embed + Send>(
114        &self,
115        documents: Vec<(Doc, OneOrMany<Embedding>)>,
116    ) -> Result<(), VectorStoreError> {
117        for (document, embeddings) in documents {
118            let id = Uuid::new_v4();
119            let json_document = serde_json::to_value(&document).unwrap();
120
121            for embedding in embeddings {
122                let embedding_text = embedding.document;
123                let embedding: Vec<f64> = embedding.vec;
124
125                sqlx::query(
126                    format!(
127                        "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
128                        self.documents_table
129                    )
130                    .as_str(),
131                )
132                .bind(id)
133                .bind(&json_document)
134                .bind(&embedding_text)
135                .bind(&embedding)
136                .execute(&self.pg_pool)
137                .await
138                .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
139            }
140        }
141
142        Ok(())
143    }
144}
145
146impl<Model: EmbeddingModel> VectorStoreIndex for PostgresVectorStore<Model> {
147    /// Get the top n documents based on the distance to the given query.
148    /// The result is a list of tuples of the form (score, id, document)
149    async fn top_n<T: for<'a> Deserialize<'a> + Send>(
150        &self,
151        query: &str,
152        n: usize,
153    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
154        let embedded_query: pgvector::Vector = self
155            .model
156            .embed_text(query)
157            .await?
158            .vec
159            .iter()
160            .map(|&x| x as f32)
161            .collect::<Vec<f32>>()
162            .into();
163
164        let rows: Vec<SearchResult> = sqlx::query_as(self.search_query_full().as_str())
165            .bind(embedded_query)
166            .bind(n as i64)
167            .fetch_all(&self.pg_pool)
168            .await
169            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
170
171        let rows: Vec<(f64, String, T)> = rows
172            .into_iter()
173            .flat_map(SearchResult::into_result)
174            .collect();
175
176        Ok(rows)
177    }
178
179    /// Same as `top_n` but returns the document ids only.
180    async fn top_n_ids(
181        &self,
182        query: &str,
183        n: usize,
184    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
185        let embedded_query: pgvector::Vector = self
186            .model
187            .embed_text(query)
188            .await?
189            .vec
190            .iter()
191            .map(|&x| x as f32)
192            .collect::<Vec<f32>>()
193            .into();
194
195        let rows: Vec<SearchResultOnlyId> = sqlx::query_as(self.search_query_only_ids().as_str())
196            .bind(embedded_query)
197            .bind(n as i64)
198            .fetch_all(&self.pg_pool)
199            .await
200            .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
201
202        let rows: Vec<(f64, String)> = rows
203            .into_iter()
204            .map(|row| (row.distance, row.id.to_string()))
205            .collect();
206
207        Ok(rows)
208    }
209}