Skip to main content

daimon_plugin_pgvector/
store.rs

//! [`PgVectorStore`] — a pgvector-backed [`VectorStore`] implementation.
2
3use std::collections::HashMap;
4
5use deadpool_postgres::Pool;
6use daimon_core::vector_store::VectorStore;
7use daimon_core::{DaimonError, Document, Result, ScoredDocument};
8use pgvector::Vector;
9
10use crate::DistanceMetric;
11
12/// A [`VectorStore`] backed by PostgreSQL with the pgvector extension.
13///
14/// Use [`PgVectorStoreBuilder`](crate::PgVectorStoreBuilder) to construct.
15pub struct PgVectorStore {
16    pub(crate) pool: Pool,
17    pub(crate) table: String,
18    pub(crate) dimensions: usize,
19    pub(crate) distance_metric: DistanceMetric,
20}
21
22impl PgVectorStore {
23    /// Returns the distance operator used in SQL ORDER BY clauses.
24    fn distance_operator(&self) -> &'static str {
25        match self.distance_metric {
26            DistanceMetric::Cosine => "<=>",
27            DistanceMetric::L2 => "<->",
28            DistanceMetric::InnerProduct => "<#>",
29        }
30    }
31
32    /// Returns a reference to the underlying connection pool.
33    pub fn pool(&self) -> &Pool {
34        &self.pool
35    }
36
37    /// Returns the table name used by this store.
38    pub fn table(&self) -> &str {
39        &self.table
40    }
41
42    /// Returns the configured vector dimensions.
43    pub fn dimensions(&self) -> usize {
44        self.dimensions
45    }
46}
47
48impl VectorStore for PgVectorStore {
49    async fn upsert(&self, id: &str, embedding: Vec<f32>, document: Document) -> Result<()> {
50        if embedding.len() != self.dimensions {
51            return Err(DaimonError::Other(format!(
52                "embedding dimension mismatch: expected {}, got {}",
53                self.dimensions,
54                embedding.len()
55            )));
56        }
57
58        let client = self.pool.get().await.map_err(|e| {
59            DaimonError::Other(format!("pgvector pool error: {e}"))
60        })?;
61
62        let vec = Vector::from(embedding);
63        let metadata = serde_json::to_value(&document.metadata)
64            .map_err(|e| DaimonError::Other(format!("metadata serialization error: {e}")))?;
65
66        let sql = format!(
67            "INSERT INTO {} (id, embedding, content, metadata) VALUES ($1, $2, $3, $4) \
68             ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, \
69             content = EXCLUDED.content, metadata = EXCLUDED.metadata",
70            self.table
71        );
72
73        client
74            .execute(&sql as &str, &[&id, &vec, &document.content, &metadata])
75            .await
76            .map_err(|e| DaimonError::Other(format!("pgvector upsert error: {e}")))?;
77
78        Ok(())
79    }
80
81    async fn query(&self, embedding: Vec<f32>, top_k: usize) -> Result<Vec<ScoredDocument>> {
82        if embedding.len() != self.dimensions {
83            return Err(DaimonError::Other(format!(
84                "embedding dimension mismatch: expected {}, got {}",
85                self.dimensions,
86                embedding.len()
87            )));
88        }
89
90        let client = self.pool.get().await.map_err(|e| {
91            DaimonError::Other(format!("pgvector pool error: {e}"))
92        })?;
93
94        let vec = Vector::from(embedding);
95        let op = self.distance_operator();
96
97        // For cosine and L2, lower distance = more similar, so score = 1 - distance.
98        // For inner product, pgvector returns negative inner product, so score = -distance.
99        let score_expr = match self.distance_metric {
100            DistanceMetric::Cosine | DistanceMetric::L2 => {
101                format!("1.0 - (embedding {op} $1)")
102            }
103            DistanceMetric::InnerProduct => {
104                format!("-(embedding {op} $1)")
105            }
106        };
107
108        let sql = format!(
109            "SELECT id, content, metadata, {score_expr} AS score \
110             FROM {} ORDER BY embedding {op} $1 LIMIT $2",
111            self.table
112        );
113
114        let rows = client
115            .query(&sql as &str, &[&vec, &(top_k as i64)])
116            .await
117            .map_err(|e| DaimonError::Other(format!("pgvector query error: {e}")))?;
118
119        let mut results = Vec::with_capacity(rows.len());
120        for row in rows {
121            let content: String = row.get("content");
122            let metadata_val: serde_json::Value = row.get("metadata");
123            let score: f64 = row.get("score");
124
125            let metadata: HashMap<String, serde_json::Value> =
126                serde_json::from_value(metadata_val).unwrap_or_default();
127
128            let doc = Document {
129                content,
130                metadata,
131                score: Some(score),
132            };
133            results.push(ScoredDocument::new(doc, score));
134        }
135
136        Ok(results)
137    }
138
139    async fn delete(&self, id: &str) -> Result<bool> {
140        let client = self.pool.get().await.map_err(|e| {
141            DaimonError::Other(format!("pgvector pool error: {e}"))
142        })?;
143
144        let sql = format!("DELETE FROM {} WHERE id = $1", self.table);
145        let deleted = client
146            .execute(&sql as &str, &[&id])
147            .await
148            .map_err(|e| DaimonError::Other(format!("pgvector delete error: {e}")))?;
149
150        Ok(deleted > 0)
151    }
152
153    async fn count(&self) -> Result<usize> {
154        let client = self.pool.get().await.map_err(|e| {
155            DaimonError::Other(format!("pgvector pool error: {e}"))
156        })?;
157
158        let sql = format!("SELECT COUNT(*) AS cnt FROM {}", self.table);
159        let row = client
160            .query_one(&sql as &str, &[])
161            .await
162            .map_err(|e| DaimonError::Other(format!("pgvector count error: {e}")))?;
163
164        let count: i64 = row.get("cnt");
165        Ok(count as usize)
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Pool handle used only to populate the `pool` field of test stores;
    /// the assertions below never request a client from it.
    fn create_dummy_pool() -> Pool {
        let cfg = deadpool_postgres::Config {
            dbname: Some("test".into()),
            host: Some("localhost".into()),
            port: Some(5432),
            ..Default::default()
        };
        cfg.create_pool(None, tokio_postgres::NoTls).unwrap()
    }

    /// Every metric must map to its pgvector SQL operator.
    #[test]
    fn test_distance_operator() {
        let expectations = [
            (DistanceMetric::Cosine, "<=>"),
            (DistanceMetric::L2, "<->"),
            (DistanceMetric::InnerProduct, "<#>"),
        ];

        for (distance_metric, operator) in expectations {
            let store = PgVectorStore {
                pool: create_dummy_pool(),
                table: "t".into(),
                dimensions: 3,
                distance_metric,
            };
            assert_eq!(store.distance_operator(), operator);
        }
    }
}