1use std::fmt::Display;
2
3use rig::{
4 Embed, OneOrMany,
5 embeddings::{Embedding, EmbeddingModel},
6 vector_store::{
7 InsertDocuments, VectorStoreError, VectorStoreIndex, request::VectorSearchRequest,
8 },
9};
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11use serde_json::Value;
12use sqlx::PgPool;
13use uuid::Uuid;
14
15pub struct PostgresVectorStore<Model: EmbeddingModel> {
16 model: Model,
17 pg_pool: PgPool,
18 documents_table: String,
19 distance_function: PgVectorDistanceFunction,
20}
21
/// Distance operators that can be used to rank rows against a query vector.
///
/// The [`Display`] implementation renders each variant as the SQL operator
/// that gets spliced into the search query between the `embedding` column and
/// the query vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// L2 distance, rendered as `<->`.
    L2,
    /// Inner product, rendered as `<#>`.
    InnerProduct,
    /// Cosine distance, rendered as `<=>`.
    Cosine,
    /// L1 distance, rendered as `<+>`.
    L1,
    /// Hamming distance, rendered as `<~>`.
    Hamming,
    /// Jaccard distance, rendered as `<%>`.
    Jaccard,
}

impl Display for PgVectorDistanceFunction {
    /// Writes the SQL operator that corresponds to the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // The match is exhaustive on purpose: adding a variant must force a
        // decision about which operator it maps to.
        let operator = match self {
            PgVectorDistanceFunction::L2 => "<->",
            PgVectorDistanceFunction::InnerProduct => "<#>",
            PgVectorDistanceFunction::Cosine => "<=>",
            PgVectorDistanceFunction::L1 => "<+>",
            PgVectorDistanceFunction::Hamming => "<~>",
            PgVectorDistanceFunction::Jaccard => "<%>",
        };
        f.write_str(operator)
    }
}
51
52#[derive(Debug, Deserialize, sqlx::FromRow)]
53pub struct SearchResult {
54 id: Uuid,
55 document: Value,
56 distance: f64,
58}
59
60#[derive(Debug, Deserialize, sqlx::FromRow)]
61pub struct SearchResultOnlyId {
62 id: Uuid,
63 distance: f64,
64}
65
66impl SearchResult {
67 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
68 let document: T =
69 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
70 Ok((self.distance, self.id.to_string(), document))
71 }
72}
73
74impl<Model: EmbeddingModel> PostgresVectorStore<Model> {
75 pub fn new(
76 model: Model,
77 pg_pool: PgPool,
78 documents_table: Option<String>,
79 distance_function: PgVectorDistanceFunction,
80 ) -> Self {
81 Self {
82 model,
83 pg_pool,
84 documents_table: documents_table.unwrap_or(String::from("documents")),
85 distance_function,
86 }
87 }
88
89 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
90 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
91 }
92
93 fn search_query_full(&self, threshold: Option<f64>) -> String {
94 self.search_query(true, threshold)
95 }
96 fn search_query_only_ids(&self, threshold: Option<f64>) -> String {
97 self.search_query(false, threshold)
98 }
99
100 fn search_query(&self, with_document: bool, threshold: Option<f64>) -> String {
101 let document = if with_document { ", document" } else { "" };
102 format!(
103 "
104 SELECT id{}, distance FROM ( \
105 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
106 FROM {} \
107 {where_clause}
108 ORDER BY id, distance \
109 ) as d \
110 ORDER BY distance \
111 LIMIT $2",
112 document,
113 document,
114 self.distance_function,
115 self.documents_table,
116 where_clause = if let Some(threshold) = threshold {
117 format!("where distance > {threshold}")
118 } else {
119 String::new()
120 }
121 )
122 }
123}
124
125impl<Model> InsertDocuments for PostgresVectorStore<Model>
126where
127 Model: EmbeddingModel + Send + Sync,
128{
129 async fn insert_documents<Doc: Serialize + Embed + Send>(
130 &self,
131 documents: Vec<(Doc, OneOrMany<Embedding>)>,
132 ) -> Result<(), VectorStoreError> {
133 for (document, embeddings) in documents {
134 let id = Uuid::new_v4();
135 let json_document = serde_json::to_value(&document).unwrap();
136
137 for embedding in embeddings {
138 let embedding_text = embedding.document;
139 let embedding: Vec<f64> = embedding.vec;
140
141 sqlx::query(
142 format!(
143 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
144 self.documents_table
145 )
146 .as_str(),
147 )
148 .bind(id)
149 .bind(&json_document)
150 .bind(&embedding_text)
151 .bind(&embedding)
152 .execute(&self.pg_pool)
153 .await
154 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
155 }
156 }
157
158 Ok(())
159 }
160}
161
162impl<Model: EmbeddingModel> VectorStoreIndex for PostgresVectorStore<Model> {
163 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
166 &self,
167 req: VectorSearchRequest,
168 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
169 if req.samples() > i64::MAX as u64 {
170 return Err(VectorStoreError::DatastoreError(
171 format!(
172 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
173 i64::MAX
174 )
175 .into(),
176 ));
177 }
178
179 let embedded_query: pgvector::Vector = self
180 .model
181 .embed_text(req.query())
182 .await?
183 .vec
184 .iter()
185 .map(|&x| x as f32)
186 .collect::<Vec<f32>>()
187 .into();
188
189 let rows: Vec<SearchResult> =
190 sqlx::query_as(self.search_query_full(req.threshold()).as_str())
191 .bind(embedded_query)
192 .bind(req.samples() as i64)
193 .fetch_all(&self.pg_pool)
194 .await
195 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
196
197 let rows: Vec<(f64, String, T)> = rows
198 .into_iter()
199 .flat_map(SearchResult::into_result)
200 .collect();
201
202 Ok(rows)
203 }
204
205 async fn top_n_ids(
207 &self,
208 req: VectorSearchRequest,
209 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
210 if req.samples() > i64::MAX as u64 {
211 return Err(VectorStoreError::DatastoreError(
212 format!(
213 "The maximum amount of samples to return with the `rig` Postgres integration cannot be larger than {}",
214 i64::MAX
215 )
216 .into(),
217 ));
218 }
219 let embedded_query: pgvector::Vector = self
220 .model
221 .embed_text(req.query())
222 .await?
223 .vec
224 .iter()
225 .map(|&x| x as f32)
226 .collect::<Vec<f32>>()
227 .into();
228
229 let rows: Vec<SearchResultOnlyId> =
230 sqlx::query_as(self.search_query_only_ids(req.threshold()).as_str())
231 .bind(embedded_query)
232 .bind(req.samples() as i64)
233 .fetch_all(&self.pg_pool)
234 .await
235 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
236
237 let rows: Vec<(f64, String)> = rows
238 .into_iter()
239 .map(|row| (row.distance, row.id.to_string()))
240 .collect();
241
242 Ok(rows)
243 }
244}