1use crate::store::{MemoryEntry, SearchResult, VectorStore};
9use argentor_core::{ArgentorError, ArgentorResult};
10use async_trait::async_trait;
11use std::collections::HashMap;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
/// Vector store handle that carries pgvector-style settings (connection string, table,
/// vector column, dimension) alongside an in-memory map of entries.
///
/// The `VectorStore` implementation in this file reads and writes only the in-memory map;
/// the connection settings are exposed through accessors and used to render SQL text.
// NOTE(review): every field below is read by an accessor or by `insert`, so the
// `#[allow(dead_code)]` attributes look unnecessary — confirm before removing them.
pub struct PgVectorStore {
    /// Connection string as passed to `PgVectorStore::new`; stored verbatim.
    #[allow(dead_code)]
    connection_string: String,
    /// Table name interpolated into the rendered INSERT / SELECT statements.
    #[allow(dead_code)]
    table_name: String,
    /// Name of the vector column; `new` initialises it to `"embedding"`.
    #[allow(dead_code)]
    vector_column: String,
    /// Expected embedding length; `insert` rejects non-empty embeddings of another length.
    #[allow(dead_code)]
    dimension: usize,
    /// Entries keyed by their id, behind an async read/write lock.
    entries: RwLock<HashMap<Uuid, MemoryEntry>>,
}
32
33impl PgVectorStore {
34 pub fn new(
39 connection_string: impl Into<String>,
40 table_name: impl Into<String>,
41 dimension: usize,
42 ) -> Self {
43 Self {
44 connection_string: connection_string.into(),
45 table_name: table_name.into(),
46 vector_column: "embedding".to_string(),
47 dimension,
48 entries: RwLock::new(HashMap::new()),
49 }
50 }
51
52 pub fn with_vector_column(mut self, column: impl Into<String>) -> Self {
54 self.vector_column = column.into();
55 self
56 }
57
58 pub fn table_name(&self) -> &str {
60 &self.table_name
61 }
62
63 pub fn vector_column(&self) -> &str {
65 &self.vector_column
66 }
67
68 pub fn dimension(&self) -> usize {
70 self.dimension
71 }
72
73 pub fn connection_string(&self) -> &str {
75 &self.connection_string
76 }
77
78 pub fn render_insert_sql(&self) -> String {
84 format!(
85 "INSERT INTO {} (id, content, {}, metadata, session_id, created_at) \
86 VALUES ($1, $2, $3::vector, $4, $5, $6)",
87 self.table_name, self.vector_column
88 )
89 }
90
91 pub fn render_search_sql(&self) -> String {
93 format!(
94 "SELECT id, content, {col}, metadata, session_id, created_at, \
95 1 - ({col} <=> $1::vector) AS score \
96 FROM {table} ORDER BY {col} <=> $1::vector LIMIT $2",
97 col = self.vector_column,
98 table = self.table_name
99 )
100 }
101}
102
103#[async_trait]
104impl VectorStore for PgVectorStore {
105 async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
106 if !entry.embedding.is_empty() && entry.embedding.len() != self.dimension {
107 return Err(ArgentorError::Agent(format!(
108 "pgvector: dim mismatch (got {}, expected {})",
109 entry.embedding.len(),
110 self.dimension
111 )));
112 }
113 let mut entries = self.entries.write().await;
114 entries.insert(entry.id, entry);
115 Ok(())
116 }
117
118 async fn search(
119 &self,
120 query_embedding: &[f32],
121 top_k: usize,
122 session_filter: Option<Uuid>,
123 ) -> ArgentorResult<Vec<SearchResult>> {
124 if query_embedding.is_empty() {
125 return Err(ArgentorError::Agent("Empty query embedding".to_string()));
126 }
127 let entries = self.entries.read().await;
128 let mut scored: Vec<SearchResult> = entries
129 .values()
130 .filter(|e| {
131 if let Some(sid) = session_filter {
132 e.session_id == Some(sid)
133 } else {
134 true
135 }
136 })
137 .map(|e| {
138 let score = cosine(query_embedding, &e.embedding);
139 SearchResult {
140 entry: e.clone(),
141 score,
142 }
143 })
144 .collect();
145 scored.sort_by(|a, b| {
146 b.score
147 .partial_cmp(&a.score)
148 .unwrap_or(std::cmp::Ordering::Equal)
149 });
150 scored.truncate(top_k);
151 Ok(scored)
152 }
153
154 async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
155 let mut entries = self.entries.write().await;
156 Ok(entries.remove(&id).is_some())
157 }
158
159 async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
160 let entries = self.entries.read().await;
161 Ok(entries
162 .values()
163 .filter(|e| {
164 if let Some(sid) = session_filter {
165 e.session_id == Some(sid)
166 } else {
167 true
168 }
169 })
170 .cloned()
171 .collect())
172 }
173
174 async fn count(&self) -> ArgentorResult<usize> {
175 let entries = self.entries.read().await;
176 Ok(entries.len())
177 }
178}
179
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either has zero Euclidean norm
/// (which includes two empty slices); otherwise `dot(a, b) / (|a| * |b|)`.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean norm, accumulated in f32.
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    if a.len() != b.len() {
        return 0.0;
    }

    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}
193
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Connection string shared by every test; nothing in these tests connects to it.
    const DSN: &str = "postgres://u@h/d";

    /// Builds a store on `table` with the given embedding dimension.
    fn store(table: &str, dim: usize) -> PgVectorStore {
        PgVectorStore::new(DSN, table, dim)
    }

    /// Builds an entry with a fresh id, empty metadata and the current timestamp.
    fn entry(content: &str, emb: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_string(),
            embedding: emb,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    // --- configuration and SQL rendering ---

    #[test]
    fn test_new_defaults_vector_column() {
        let pg = store("docs", 384);
        assert_eq!(pg.table_name(), "docs");
        assert_eq!(pg.vector_column(), "embedding");
        assert_eq!(pg.dimension(), 384);
        assert_eq!(pg.connection_string(), "postgres://u@h/d");
    }

    #[test]
    fn test_with_vector_column() {
        let pg = store("docs", 3).with_vector_column("vec");
        assert_eq!(pg.vector_column(), "vec");
    }

    #[test]
    fn test_render_insert_sql() {
        let rendered = store("docs", 3).render_insert_sql();
        assert!(rendered.contains("INSERT INTO docs"));
        assert!(rendered.contains("embedding"));
        assert!(rendered.contains("$3::vector"));
    }

    #[test]
    fn test_render_search_sql_cosine_operator() {
        let rendered = store("docs", 3).render_search_sql();
        assert!(rendered.contains("<=>"));
        assert!(rendered.contains("ORDER BY"));
        assert!(rendered.contains("LIMIT $2"));
    }

    #[test]
    fn test_render_search_sql_uses_custom_column() {
        let rendered = store("docs", 3)
            .with_vector_column("vec")
            .render_search_sql();
        assert!(rendered.contains("vec <=> $1::vector"));
    }

    // --- insert ---

    #[tokio::test]
    async fn test_insert_and_count() {
        let pg = store("t", 2);
        pg.insert(entry("a", vec![1.0, 0.0], None)).await.unwrap();
        assert_eq!(pg.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_insert_rejects_bad_dim() {
        let pg = store("t", 3);
        let outcome = pg.insert(entry("x", vec![1.0, 0.0], None)).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_insert_allows_empty_embedding() {
        let pg = store("t", 3);
        let outcome = pg.insert(entry("pending", vec![], None)).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_insert_many() {
        let pg = store("t", 2);
        for i in 0..15 {
            let item = entry(&format!("e{i}"), vec![1.0, i as f32], None);
            pg.insert(item).await.unwrap();
        }
        assert_eq!(pg.count().await.unwrap(), 15);
    }

    // --- search ---

    #[tokio::test]
    async fn test_search_orders_by_similarity() {
        let pg = store("t", 3);
        let near = entry("near", vec![0.9, 0.1, 0.0], None);
        let far = entry("far", vec![0.0, 0.0, 1.0], None);
        pg.insert(near).await.unwrap();
        pg.insert(far).await.unwrap();
        let hits = pg.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(hits[0].entry.content, "near");
    }

    #[tokio::test]
    async fn test_search_top_k_limits() {
        let pg = store("t", 2);
        for i in 0..9 {
            let item = entry(&format!("e{i}"), vec![1.0, i as f32], None);
            pg.insert(item).await.unwrap();
        }
        let hits = pg.search(&[1.0, 0.0], 3, None).await.unwrap();
        assert_eq!(hits.len(), 3);
    }

    #[tokio::test]
    async fn test_search_empty_query_errors() {
        let pg = store("t", 2);
        let outcome = pg.search(&[], 1, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let pg = store("t", 2);
        let sid = Uuid::new_v4();
        pg.insert(entry("s", vec![1.0, 0.0], Some(sid)))
            .await
            .unwrap();
        pg.insert(entry("o", vec![1.0, 0.0], None)).await.unwrap();
        let hits = pg.search(&[1.0, 0.0], 5, Some(sid)).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entry.content, "s");
    }

    #[tokio::test]
    async fn test_search_on_empty_store() {
        let pg = store("t", 2);
        let hits = pg.search(&[1.0, 0.0], 5, None).await.unwrap();
        assert!(hits.is_empty());
    }

    // --- delete / list / count ---

    #[tokio::test]
    async fn test_delete_existing() {
        let pg = store("t", 2);
        let item = entry("x", vec![1.0, 0.0], None);
        let target = item.id;
        pg.insert(item).await.unwrap();
        assert!(pg.delete(target).await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_missing() {
        let pg = store("t", 2);
        let removed = pg.delete(Uuid::new_v4()).await.unwrap();
        assert!(!removed);
    }

    #[tokio::test]
    async fn test_list_all_and_filtered() {
        let pg = store("t", 2);
        let sid = Uuid::new_v4();
        pg.insert(entry("a", vec![1.0, 0.0], Some(sid)))
            .await
            .unwrap();
        pg.insert(entry("b", vec![0.0, 1.0], None)).await.unwrap();
        assert_eq!(pg.list(None).await.unwrap().len(), 2);
        assert_eq!(pg.list(Some(sid)).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_metadata_preserved() {
        let pg = store("t", 2);
        let mut item = entry("m", vec![1.0, 0.0], None);
        item.metadata
            .insert("source".into(), serde_json::json!("manual"));
        let target = item.id;
        pg.insert(item).await.unwrap();

        let listed = pg.list(None).await.unwrap();
        let got = listed.into_iter().find(|x| x.id == target).unwrap();
        assert_eq!(
            got.metadata.get("source").unwrap(),
            &serde_json::json!("manual")
        );
    }

    #[tokio::test]
    async fn test_count_after_deletes() {
        let pg = store("t", 2);
        let first = entry("a", vec![1.0, 0.0], None);
        let target = first.id;
        pg.insert(first).await.unwrap();
        pg.insert(entry("b", vec![0.0, 1.0], None)).await.unwrap();
        pg.delete(target).await.unwrap();
        assert_eq!(pg.count().await.unwrap(), 1);
    }
}