daimon_plugin_pgvector/
store.rs1use std::collections::HashMap;
4
5use deadpool_postgres::Pool;
6use daimon_core::vector_store::VectorStore;
7use daimon_core::{DaimonError, Document, Result, ScoredDocument};
8use pgvector::Vector;
9
10use crate::DistanceMetric;
11
/// Vector store that keeps documents and their embeddings in a single
/// PostgreSQL table and searches them with a configurable distance metric.
pub struct PgVectorStore {
    /// Connection pool; each store operation checks out one client from it.
    pub(crate) pool: Pool,
    /// Name of the backing table. It is interpolated verbatim into every SQL
    /// statement this store issues.
    pub(crate) table: String,
    /// Expected embedding length. `upsert` and `query` reject vectors whose
    /// length differs from this value.
    pub(crate) dimensions: usize,
    /// Metric that selects the distance operator and the score expression
    /// used by `query`.
    pub(crate) distance_metric: DistanceMetric,
}
21
22impl PgVectorStore {
23 fn distance_operator(&self) -> &'static str {
25 match self.distance_metric {
26 DistanceMetric::Cosine => "<=>",
27 DistanceMetric::L2 => "<->",
28 DistanceMetric::InnerProduct => "<#>",
29 }
30 }
31
32 pub fn pool(&self) -> &Pool {
34 &self.pool
35 }
36
37 pub fn table(&self) -> &str {
39 &self.table
40 }
41
42 pub fn dimensions(&self) -> usize {
44 self.dimensions
45 }
46}
47
48impl VectorStore for PgVectorStore {
49 async fn upsert(&self, id: &str, embedding: Vec<f32>, document: Document) -> Result<()> {
50 if embedding.len() != self.dimensions {
51 return Err(DaimonError::Other(format!(
52 "embedding dimension mismatch: expected {}, got {}",
53 self.dimensions,
54 embedding.len()
55 )));
56 }
57
58 let client = self.pool.get().await.map_err(|e| {
59 DaimonError::Other(format!("pgvector pool error: {e}"))
60 })?;
61
62 let vec = Vector::from(embedding);
63 let metadata = serde_json::to_value(&document.metadata)
64 .map_err(|e| DaimonError::Other(format!("metadata serialization error: {e}")))?;
65
66 let sql = format!(
67 "INSERT INTO {} (id, embedding, content, metadata) VALUES ($1, $2, $3, $4) \
68 ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, \
69 content = EXCLUDED.content, metadata = EXCLUDED.metadata",
70 self.table
71 );
72
73 client
74 .execute(&sql as &str, &[&id, &vec, &document.content, &metadata])
75 .await
76 .map_err(|e| DaimonError::Other(format!("pgvector upsert error: {e}")))?;
77
78 Ok(())
79 }
80
81 async fn query(&self, embedding: Vec<f32>, top_k: usize) -> Result<Vec<ScoredDocument>> {
82 if embedding.len() != self.dimensions {
83 return Err(DaimonError::Other(format!(
84 "embedding dimension mismatch: expected {}, got {}",
85 self.dimensions,
86 embedding.len()
87 )));
88 }
89
90 let client = self.pool.get().await.map_err(|e| {
91 DaimonError::Other(format!("pgvector pool error: {e}"))
92 })?;
93
94 let vec = Vector::from(embedding);
95 let op = self.distance_operator();
96
97 let score_expr = match self.distance_metric {
100 DistanceMetric::Cosine | DistanceMetric::L2 => {
101 format!("1.0 - (embedding {op} $1)")
102 }
103 DistanceMetric::InnerProduct => {
104 format!("-(embedding {op} $1)")
105 }
106 };
107
108 let sql = format!(
109 "SELECT id, content, metadata, {score_expr} AS score \
110 FROM {} ORDER BY embedding {op} $1 LIMIT $2",
111 self.table
112 );
113
114 let rows = client
115 .query(&sql as &str, &[&vec, &(top_k as i64)])
116 .await
117 .map_err(|e| DaimonError::Other(format!("pgvector query error: {e}")))?;
118
119 let mut results = Vec::with_capacity(rows.len());
120 for row in rows {
121 let content: String = row.get("content");
122 let metadata_val: serde_json::Value = row.get("metadata");
123 let score: f64 = row.get("score");
124
125 let metadata: HashMap<String, serde_json::Value> =
126 serde_json::from_value(metadata_val).unwrap_or_default();
127
128 let doc = Document {
129 content,
130 metadata,
131 score: Some(score),
132 };
133 results.push(ScoredDocument::new(doc, score));
134 }
135
136 Ok(results)
137 }
138
139 async fn delete(&self, id: &str) -> Result<bool> {
140 let client = self.pool.get().await.map_err(|e| {
141 DaimonError::Other(format!("pgvector pool error: {e}"))
142 })?;
143
144 let sql = format!("DELETE FROM {} WHERE id = $1", self.table);
145 let deleted = client
146 .execute(&sql as &str, &[&id])
147 .await
148 .map_err(|e| DaimonError::Other(format!("pgvector delete error: {e}")))?;
149
150 Ok(deleted > 0)
151 }
152
153 async fn count(&self) -> Result<usize> {
154 let client = self.pool.get().await.map_err(|e| {
155 DaimonError::Other(format!("pgvector pool error: {e}"))
156 })?;
157
158 let sql = format!("SELECT COUNT(*) AS cnt FROM {}", self.table);
159 let row = client
160 .query_one(&sql as &str, &[])
161 .await
162 .map_err(|e| DaimonError::Other(format!("pgvector count error: {e}")))?;
163
164 let count: i64 = row.get("cnt");
165 Ok(count as usize)
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pool from a fixed localhost configuration so a store value
    /// can be constructed inside unit tests.
    // NOTE(review): presumably pool creation does not open a connection, so
    // no running database is needed — confirm.
    fn create_dummy_pool() -> Pool {
        let cfg = deadpool_postgres::Config {
            host: Some(String::from("localhost")),
            port: Some(5432),
            dbname: Some(String::from("test")),
            ..Default::default()
        };
        cfg.create_pool(None, tokio_postgres::NoTls).unwrap()
    }

    /// Every metric must map to its expected SQL distance operator.
    #[test]
    fn test_distance_operator() {
        let mut store = PgVectorStore {
            pool: create_dummy_pool(),
            table: "t".into(),
            dimensions: 3,
            distance_metric: DistanceMetric::Cosine,
        };

        let cases = [
            (DistanceMetric::Cosine, "<=>"),
            (DistanceMetric::L2, "<->"),
            (DistanceMetric::InnerProduct, "<#>"),
        ];

        for (metric, expected) in cases {
            store.distance_metric = metric;
            assert_eq!(store.distance_operator(), expected);
        }
    }
}