1use std::fmt::Display;
2
3use rig::{
4 embeddings::{Embedding, EmbeddingModel},
5 vector_store::{VectorStoreError, VectorStoreIndex},
6 Embed, OneOrMany,
7};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10use sqlx::PgPool;
11use uuid::Uuid;
12
13pub struct PostgresVectorStore<Model: EmbeddingModel> {
14 model: Model,
15 pg_pool: PgPool,
16 documents_table: String,
17 distance_function: PgVectorDistanceFunction,
18}
19
/// Distance operators that can be used to rank stored embeddings against a query.
///
/// The `Display` implementation of this type renders each variant as the SQL
/// operator that is spliced into the search query. The per-variant notes below
/// follow the pgvector documentation; availability of an operator depends on the
/// pgvector version and on the column type in use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PgVectorDistanceFunction {
    /// L2 (Euclidean) distance, rendered as `<->`.
    L2,
    /// Inner product, rendered as `<#>` (pgvector documents this operator as
    /// returning the *negative* inner product).
    InnerProduct,
    /// Cosine distance, rendered as `<=>`.
    Cosine,
    /// L1 (taxicab) distance, rendered as `<+>`.
    L1,
    /// Hamming distance, rendered as `<~>` (documented by pgvector for binary vectors).
    Hamming,
    /// Jaccard distance, rendered as `<%>` (documented by pgvector for binary vectors).
    Jaccard,
}
36
37impl Display for PgVectorDistanceFunction {
38 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
39 match self {
40 PgVectorDistanceFunction::L2 => write!(f, "<->"),
41 PgVectorDistanceFunction::InnerProduct => write!(f, "<#>"),
42 PgVectorDistanceFunction::Cosine => write!(f, "<=>"),
43 PgVectorDistanceFunction::L1 => write!(f, "<+>"),
44 PgVectorDistanceFunction::Hamming => write!(f, "<~>"),
45 PgVectorDistanceFunction::Jaccard => write!(f, "<%>"),
46 }
47 }
48}
49
50#[derive(Debug, Deserialize, sqlx::FromRow)]
51pub struct SearchResult {
52 id: Uuid,
53 document: Value,
54 distance: f64,
56}
57
58#[derive(Debug, Deserialize, sqlx::FromRow)]
59pub struct SearchResultOnlyId {
60 id: Uuid,
61 distance: f64,
62}
63
64impl SearchResult {
65 pub fn into_result<T: DeserializeOwned>(self) -> Result<(f64, String, T), VectorStoreError> {
66 let document: T =
67 serde_json::from_value(self.document).map_err(VectorStoreError::JsonError)?;
68 Ok((self.distance, self.id.to_string(), document))
69 }
70}
71
72impl<Model: EmbeddingModel> PostgresVectorStore<Model> {
73 pub fn new(
74 model: Model,
75 pg_pool: PgPool,
76 documents_table: Option<String>,
77 distance_function: PgVectorDistanceFunction,
78 ) -> Self {
79 Self {
80 model,
81 pg_pool,
82 documents_table: documents_table.unwrap_or(String::from("documents")),
83 distance_function,
84 }
85 }
86
87 pub fn with_defaults(model: Model, pg_pool: PgPool) -> Self {
88 Self::new(model, pg_pool, None, PgVectorDistanceFunction::Cosine)
89 }
90
91 fn search_query_full(&self) -> String {
92 self.search_query(true)
93 }
94 fn search_query_only_ids(&self) -> String {
95 self.search_query(false)
96 }
97
98 fn search_query(&self, with_document: bool) -> String {
99 let document = if with_document { ", document" } else { "" };
100 format!(
101 "
102 SELECT id{}, distance FROM ( \
103 SELECT DISTINCT ON (id) id{}, embedding {} $1 as distance \
104 FROM {} \
105 ORDER BY id, distance \
106 ) as d \
107 ORDER BY distance \
108 LIMIT $2",
109 document, document, self.distance_function, self.documents_table
110 )
111 }
112
113 pub async fn insert_documents<Doc: Serialize + Embed + Send>(
114 &self,
115 documents: Vec<(Doc, OneOrMany<Embedding>)>,
116 ) -> Result<(), VectorStoreError> {
117 for (document, embeddings) in documents {
118 let id = Uuid::new_v4();
119 let json_document = serde_json::to_value(&document).unwrap();
120
121 for embedding in embeddings {
122 let embedding_text = embedding.document;
123 let embedding: Vec<f64> = embedding.vec;
124
125 sqlx::query(
126 format!(
127 "INSERT INTO {} (id, document, embedded_text, embedding) VALUES ($1, $2, $3, $4)",
128 self.documents_table
129 )
130 .as_str(),
131 )
132 .bind(id)
133 .bind(&json_document)
134 .bind(&embedding_text)
135 .bind(&embedding)
136 .execute(&self.pg_pool)
137 .await
138 .map_err(|e| VectorStoreError::DatastoreError(e.into()))?;
139 }
140 }
141
142 Ok(())
143 }
144}
145
146impl<Model: EmbeddingModel> VectorStoreIndex for PostgresVectorStore<Model> {
147 async fn top_n<T: for<'a> Deserialize<'a> + Send>(
150 &self,
151 query: &str,
152 n: usize,
153 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
154 let embedded_query: pgvector::Vector = self
155 .model
156 .embed_text(query)
157 .await?
158 .vec
159 .iter()
160 .map(|&x| x as f32)
161 .collect::<Vec<f32>>()
162 .into();
163
164 let rows: Vec<SearchResult> = sqlx::query_as(self.search_query_full().as_str())
165 .bind(embedded_query)
166 .bind(n as i64)
167 .fetch_all(&self.pg_pool)
168 .await
169 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
170
171 let rows: Vec<(f64, String, T)> = rows
172 .into_iter()
173 .flat_map(SearchResult::into_result)
174 .collect();
175
176 Ok(rows)
177 }
178
179 async fn top_n_ids(
181 &self,
182 query: &str,
183 n: usize,
184 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
185 let embedded_query: pgvector::Vector = self
186 .model
187 .embed_text(query)
188 .await?
189 .vec
190 .iter()
191 .map(|&x| x as f32)
192 .collect::<Vec<f32>>()
193 .into();
194
195 let rows: Vec<SearchResultOnlyId> = sqlx::query_as(self.search_query_only_ids().as_str())
196 .bind(embedded_query)
197 .bind(n as i64)
198 .fetch_all(&self.pg_pool)
199 .await
200 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
201
202 let rows: Vec<(f64, String)> = rows
203 .into_iter()
204 .map(|row| (row.distance, row.id.to_string()))
205 .collect();
206
207 Ok(rows)
208 }
209}