Skip to main content

argentor_memory/
pgvector.rs

//! pgvector (PostgreSQL extension) vector store adapter.
//!
//! pgvector adds a `vector` column type to PostgreSQL and supports
//! L2/cosine/inner-product distance queries. This adapter implements the
//! [`VectorStore`] trait with a stub backend; real SQL calls would require
//! a Postgres driver and are not wired here.
7
8use crate::store::{MemoryEntry, SearchResult, VectorStore};
9use argentor_core::{ArgentorError, ArgentorResult};
10use async_trait::async_trait;
11use std::collections::HashMap;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
15/// pgvector adapter for PostgreSQL with the `vector` extension.
16pub struct PgVectorStore {
17    /// PostgreSQL connection string (e.g., "postgres://user:pass@host/db").
18    #[allow(dead_code)]
19    connection_string: String,
20    /// Table that stores the vectors.
21    #[allow(dead_code)]
22    table_name: String,
23    /// Column holding the `vector` value (defaults to "embedding").
24    #[allow(dead_code)]
25    vector_column: String,
26    /// Declared dimension (pgvector requires fixed dim per column).
27    #[allow(dead_code)]
28    dimension: usize,
29    /// Stub in-memory storage.
30    entries: RwLock<HashMap<Uuid, MemoryEntry>>,
31}
32
33impl PgVectorStore {
34    /// Create a new pgvector adapter in stub mode.
35    ///
36    /// The vector column defaults to `"embedding"`. Use
37    /// [`Self::with_vector_column`] to override.
38    pub fn new(
39        connection_string: impl Into<String>,
40        table_name: impl Into<String>,
41        dimension: usize,
42    ) -> Self {
43        Self {
44            connection_string: connection_string.into(),
45            table_name: table_name.into(),
46            vector_column: "embedding".to_string(),
47            dimension,
48            entries: RwLock::new(HashMap::new()),
49        }
50    }
51
52    /// Override the vector column name (default `"embedding"`).
53    pub fn with_vector_column(mut self, column: impl Into<String>) -> Self {
54        self.vector_column = column.into();
55        self
56    }
57
58    /// Return the configured table name.
59    pub fn table_name(&self) -> &str {
60        &self.table_name
61    }
62
63    /// Return the configured vector column name.
64    pub fn vector_column(&self) -> &str {
65        &self.vector_column
66    }
67
68    /// Return the configured dimension.
69    pub fn dimension(&self) -> usize {
70        self.dimension
71    }
72
73    /// Return the configured connection string.
74    pub fn connection_string(&self) -> &str {
75        &self.connection_string
76    }
77
78    /// Render a stub SQL `INSERT` statement for documentation / debug.
79    ///
80    /// This does NOT execute anything — it is a helper that makes the
81    /// underlying SQL shape visible for tests and for users planning a
82    /// real driver integration.
83    pub fn render_insert_sql(&self) -> String {
84        format!(
85            "INSERT INTO {} (id, content, {}, metadata, session_id, created_at) \
86             VALUES ($1, $2, $3::vector, $4, $5, $6)",
87            self.table_name, self.vector_column
88        )
89    }
90
91    /// Render a stub SQL `SELECT` statement for cosine similarity search.
92    pub fn render_search_sql(&self) -> String {
93        format!(
94            "SELECT id, content, {col}, metadata, session_id, created_at, \
95             1 - ({col} <=> $1::vector) AS score \
96             FROM {table} ORDER BY {col} <=> $1::vector LIMIT $2",
97            col = self.vector_column,
98            table = self.table_name
99        )
100    }
101}
102
103#[async_trait]
104impl VectorStore for PgVectorStore {
105    async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
106        if !entry.embedding.is_empty() && entry.embedding.len() != self.dimension {
107            return Err(ArgentorError::Agent(format!(
108                "pgvector: dim mismatch (got {}, expected {})",
109                entry.embedding.len(),
110                self.dimension
111            )));
112        }
113        let mut entries = self.entries.write().await;
114        entries.insert(entry.id, entry);
115        Ok(())
116    }
117
118    async fn search(
119        &self,
120        query_embedding: &[f32],
121        top_k: usize,
122        session_filter: Option<Uuid>,
123    ) -> ArgentorResult<Vec<SearchResult>> {
124        if query_embedding.is_empty() {
125            return Err(ArgentorError::Agent("Empty query embedding".to_string()));
126        }
127        let entries = self.entries.read().await;
128        let mut scored: Vec<SearchResult> = entries
129            .values()
130            .filter(|e| {
131                if let Some(sid) = session_filter {
132                    e.session_id == Some(sid)
133                } else {
134                    true
135                }
136            })
137            .map(|e| {
138                let score = cosine(query_embedding, &e.embedding);
139                SearchResult {
140                    entry: e.clone(),
141                    score,
142                }
143            })
144            .collect();
145        scored.sort_by(|a, b| {
146            b.score
147                .partial_cmp(&a.score)
148                .unwrap_or(std::cmp::Ordering::Equal)
149        });
150        scored.truncate(top_k);
151        Ok(scored)
152    }
153
154    async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
155        let mut entries = self.entries.write().await;
156        Ok(entries.remove(&id).is_some())
157    }
158
159    async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
160        let entries = self.entries.read().await;
161        Ok(entries
162            .values()
163            .filter(|e| {
164                if let Some(sid) = session_filter {
165                    e.session_id == Some(sid)
166                } else {
167                    true
168                }
169            })
170            .cloned()
171            .collect())
172    }
173
174    async fn count(&self) -> ArgentorResult<usize> {
175        let entries = self.entries.read().await;
176        Ok(entries.len())
177    }
178}
179
/// Cosine similarity between `a` and `b`, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length (which includes comparing
/// against an empty embedding) and when either vector has zero magnitude.
///
/// The dot product and squared norms are accumulated in `f64`: squaring in
/// `f32` overflows to infinity for large components and underflows to zero
/// for tiny ones, which would turn a perfectly valid comparison into NaN or
/// a spurious `0.0`. The quotient is clamped because rounding can push it
/// marginally outside the mathematical range. Non-finite components (NaN or
/// infinity) are not sanitised and can still produce a NaN result.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    // A squared norm is exactly zero only when every component is zero.
    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    let similarity = dot / (sq_a.sqrt() * sq_b.sqrt());
    // The value is within [-1, 1] (or NaN), so narrowing only drops precision.
    similarity.clamp(-1.0, 1.0) as f32
}
193
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Connection string shared by every test; the stub only stores it.
    const DSN: &str = "postgres://u@h/d";

    /// Store over table `"t"` with the given dimension.
    fn stub(dimension: usize) -> PgVectorStore {
        PgVectorStore::new(DSN, "t", dimension)
    }

    /// Three-dimensional store over table `"docs"`, used by the SQL rendering tests.
    fn docs_store() -> PgVectorStore {
        PgVectorStore::new(DSN, "docs", 3)
    }

    /// Build an entry with a fresh id, no metadata and the current timestamp.
    fn entry(content: &str, emb: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_string(),
            embedding: emb,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    #[test]
    fn test_new_defaults_vector_column() {
        let store = PgVectorStore::new(DSN, "docs", 384);
        assert_eq!(store.table_name(), "docs");
        assert_eq!(store.vector_column(), "embedding");
        assert_eq!(store.dimension(), 384);
        assert_eq!(store.connection_string(), "postgres://u@h/d");
    }

    #[test]
    fn test_with_vector_column() {
        let store = docs_store().with_vector_column("vec");
        assert_eq!(store.vector_column(), "vec");
    }

    #[test]
    fn test_render_insert_sql() {
        let sql = docs_store().render_insert_sql();
        assert!(sql.contains("INSERT INTO docs"));
        assert!(sql.contains("embedding"));
        assert!(sql.contains("$3::vector"));
    }

    #[test]
    fn test_render_search_sql_cosine_operator() {
        let sql = docs_store().render_search_sql();
        assert!(sql.contains("<=>"));
        assert!(sql.contains("ORDER BY"));
        assert!(sql.contains("LIMIT $2"));
    }

    #[test]
    fn test_render_search_sql_uses_custom_column() {
        let sql = docs_store().with_vector_column("vec").render_search_sql();
        assert!(sql.contains("vec <=> $1::vector"));
    }

    #[tokio::test]
    async fn test_insert_and_count() {
        let store = stub(2);
        store
            .insert(entry("a", vec![1.0, 0.0], None))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_insert_rejects_bad_dim() {
        let store = stub(3);
        let two_dims = entry("x", vec![1.0, 0.0], None);
        assert!(store.insert(two_dims).await.is_err());
    }

    #[tokio::test]
    async fn test_insert_allows_empty_embedding() {
        let store = stub(3);
        let no_embedding = entry("pending", vec![], None);
        assert!(store.insert(no_embedding).await.is_ok());
    }

    #[tokio::test]
    async fn test_insert_many() {
        let store = stub(2);
        for i in 0..15 {
            let e = entry(&format!("e{i}"), vec![1.0, i as f32], None);
            store.insert(e).await.unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 15);
    }

    #[tokio::test]
    async fn test_search_orders_by_similarity() {
        let store = stub(3);
        let near = entry("near", vec![0.9, 0.1, 0.0], None);
        let far = entry("far", vec![0.0, 0.0, 1.0], None);
        store.insert(near).await.unwrap();
        store.insert(far).await.unwrap();

        let hits = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(hits[0].entry.content, "near");
    }

    #[tokio::test]
    async fn test_search_top_k_limits() {
        let store = stub(2);
        for i in 0..9 {
            let e = entry(&format!("e{i}"), vec![1.0, i as f32], None);
            store.insert(e).await.unwrap();
        }
        let hits = store.search(&[1.0, 0.0], 3, None).await.unwrap();
        assert_eq!(hits.len(), 3);
    }

    #[tokio::test]
    async fn test_search_empty_query_errors() {
        let store = stub(2);
        assert!(store.search(&[], 1, None).await.is_err());
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let store = stub(2);
        let sid = Uuid::new_v4();
        let in_session = entry("s", vec![1.0, 0.0], Some(sid));
        let no_session = entry("o", vec![1.0, 0.0], None);
        store.insert(in_session).await.unwrap();
        store.insert(no_session).await.unwrap();

        let hits = store.search(&[1.0, 0.0], 5, Some(sid)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.content, "s");
    }

    #[tokio::test]
    async fn test_delete_existing() {
        let store = stub(2);
        let e = entry("x", vec![1.0, 0.0], None);
        let id = e.id;
        store.insert(e).await.unwrap();
        assert!(store.delete(id).await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_missing() {
        let store = stub(2);
        let unknown = Uuid::new_v4();
        assert!(!store.delete(unknown).await.unwrap());
    }

    #[tokio::test]
    async fn test_list_all_and_filtered() {
        let store = stub(2);
        let sid = Uuid::new_v4();
        let in_session = entry("a", vec![1.0, 0.0], Some(sid));
        let no_session = entry("b", vec![0.0, 1.0], None);
        store.insert(in_session).await.unwrap();
        store.insert(no_session).await.unwrap();

        assert_eq!(store.list(None).await.unwrap().len(), 2);
        assert_eq!(store.list(Some(sid)).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_metadata_preserved() {
        let store = stub(2);
        let mut e = entry("m", vec![1.0, 0.0], None);
        e.metadata
            .insert("source".into(), serde_json::json!("manual"));
        let id = e.id;
        store.insert(e).await.unwrap();

        let listed = store.list(None).await.unwrap();
        let got = listed.iter().find(|x| x.id == id).unwrap();
        assert_eq!(
            got.metadata.get("source").unwrap(),
            &serde_json::json!("manual")
        );
    }

    #[tokio::test]
    async fn test_count_after_deletes() {
        let store = stub(2);
        let doomed = entry("a", vec![1.0, 0.0], None);
        let doomed_id = doomed.id;
        store.insert(doomed).await.unwrap();
        store
            .insert(entry("b", vec![0.0, 1.0], None))
            .await
            .unwrap();

        store.delete(doomed_id).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_search_on_empty_store() {
        let store = stub(2);
        let hits = store.search(&[1.0, 0.0], 5, None).await.unwrap();
        assert!(hits.is_empty());
    }
}