1use std::fmt::Display;
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex},
7};
8use serde::{Deserialize, Serialize, de::DeserializeOwned};
9use serde_json::Value;
10use sqlx::PgPool;
11use uuid::Uuid;
12
/// Vector store backed by a Postgres table and accessed through an `sqlx` connection pool.
pub struct PostgresVectorStore<Model: EmbeddingModel> {
    /// Embedding model used to embed search queries before they are sent to the database.
    model: Model,
    /// Connection pool on which every insert and search statement is executed.
    pg_pool: PgPool,
    /// Name of the table holding the documents; it is interpolated directly into the SQL text.
    documents_table: String,
    /// Distance operator placed between the `embedding` column and the query vector in searches.
    distance_function: PgVectorDistanceFunction,
}
19
/// Distance function used to rank rows during a similarity search.
///
/// Each variant is rendered by its `Display` implementation as the SQL operator
/// shown next to it below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// Rendered as `<->`.
    L2,
    /// Rendered as `<#>`.
    InnerProduct,
    /// Rendered as `<=>`.
    Cosine,
    /// Rendered as `<+>`.
    L1,
    /// Rendered as `<~>`.
    Hamming,
    /// Rendered as `<%>`.
    Jaccard,
}
36
37impl Display for PgVectorDistanceFunction {
38 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
39 match self {
40 PgVectorDistanceFunction::L2 => write!(f, "<->"),
41 PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
42 PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
43 PgVectorDistanceFunction::L1 => write!(f, "<+>"),
44 PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
45 PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
46 }
47 }
48}
49
/// One row of the full search query: document id, JSON payload and distance to the query vector.
#[derive(Debug, Deserialize, sqlx::FromRow)]
pub struct SearchResult {
    /// Identifier of the matched document.
    id: Uuid,
    /// Stored document as raw JSON; turned into a typed value by `into_result`.
    document: Value,
    /// Distance between the row's embedding and the query vector, as computed by the database.
    distance: f64,
}
57
/// One row of the id-only search query: document id and distance, without the document payload.
#[derive(Debug, Deserialize, sqlx::FromRow)]
pub struct SearchResultOnlyId {
    /// Identifier of the matched document.
    id: Uuid,
    /// Distance between the row's embedding and the query vector, as computed by the database.
    distance: f64,
}
63
64impl SearchResult {
65 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
66 let document: T =
67 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
68 Ok((self.distance, self.id.to_string(), document))
69 }
70}
71
72impl<Model: EmbeddingModel> PostgresVectorStore<Model> {
73 pub fn new(
74 model: Model,
75 pg_pool: PgPool,
76 documents_table: Option<String>,
77 distance_function: PgVectorDistanceFunction,
78 ) -> Self {
79 Self {
80 model,
81 pg_pool,
82 documents_table: documents_table.unwrap_or(String::from("documents")),
83 distance_function,
84 }
85 }
86
87 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
88 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
89 }
90
91 fn search_query_full(&self) -> String {
92 self.search_query(true)
93 }
94 fn search_query_only_ids(&self) -> String {
95 self.search_query(false)
96 }
97
98 fn search_query(&self, with_document: bool) -> String {
99 let document = if with_document { ", document" } else { "" };
100 format!(
101 "
102 SELECT id{}, distance FROM ( \
103 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
104 FROM {} \
105 ORDER BY id, distance \
106 ) as d \
107 ORDER BY distance \
108 LIMIT $2",
109 document, document, self.distance_function, self.documents_table
110 )
111 }
112}
113
114impl<Model> InsertDocuments for PostgresVectorStore<Model>
115where
116 Model: EmbeddingModel + Send + Sync,
117{
118 async fn insert_documents<Doc: Serialize + Embed + Send>(
119 &self,
120 documents: Vec<(Doc, OneOrMany<Embedding>)>,
121 ) -> Result<(), VectorStoreError> {
122 for (document, embeddings) in documents {
123 let id = Uuid::new_v4();
124 let json_document = serde_json::to_value(&document).unwrap();
125
126 for embedding in embeddings {
127 let embedding_text = embedding.document;
128 let embedding: Vec<f64> = embedding.vec;
129
130 sqlx::query(
131 format!(
132 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
133 self.documents_table
134 )
135 .as_str(),
136 )
137 .bind(id)
138 .bind(&json_document)
139 .bind(&embedding_text)
140 .bind(&embedding)
141 .execute(&self.pg_pool)
142 .await
143 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
144 }
145 }
146
147 Ok(())
148 }
149}
150
151impl<Model: EmbeddingModel> VectorStoreIndex for PostgresVectorStore<Model> {
152 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
155 &self,
156 query: &str,
157 n: usize,
158 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
159 let embedded_query: pgvector::Vector = self
160 .model
161 .embed_text(query)
162 .await?
163 .vec
164 .iter()
165 .map(|&x| x as f32)
166 .collect::<Vec<f32>>()
167 .into();
168
169 let rows: Vec<SearchResult> = sqlx::query_as(self.search_query_full().as_str())
170 .bind(embedded_query)
171 .bind(n as i64)
172 .fetch_all(&self.pg_pool)
173 .await
174 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
175
176 let rows: Vec<(f64, String, T)> = rows
177 .into_iter()
178 .flat_map(SearchResult::into_result)
179 .collect();
180
181 Ok(rows)
182 }
183
184 async fn top_n_ids(
186 &self,
187 query: &str,
188 n: usize,
189 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
190 let embedded_query: pgvector::Vector = self
191 .model
192 .embed_text(query)
193 .await?
194 .vec
195 .iter()
196 .map(|&x| x as f32)
197 .collect::<Vec<f32>>()
198 .into();
199
200 let rows: Vec<SearchResultOnlyId> = sqlx::query_as(self.search_query_only_ids().as_str())
201 .bind(embedded_query)
202 .bind(n as i64)
203 .fetch_all(&self.pg_pool)
204 .await
205 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
206
207 let rows: Vec<(f64, String)> = rows
208 .into_iter()
209 .map(|row| (row.distance, row.id.to_string()))
210 .collect();
211
212 Ok(rows)
213 }
214}