1use crate::store::{MemoryEntry, SearchResult, VectorStore};
12use argentor_core::{ArgentorError, ArgentorResult};
13use async_trait::async_trait;
14use std::collections::HashMap;
15use tokio::sync::RwLock;
16use uuid::Uuid;
17
18pub struct PineconeStore {
24 #[allow(dead_code)]
26 api_key: String,
27 #[allow(dead_code)]
29 index_name: String,
30 #[allow(dead_code)]
32 environment: String,
33 #[allow(dead_code)]
35 namespace: Option<String>,
36 #[cfg(feature = "http-vectorstore")]
38 #[allow(dead_code)]
39 client: Option<reqwest::Client>,
40 entries: RwLock<HashMap<Uuid, MemoryEntry>>,
42}
43
44impl PineconeStore {
45 pub fn new(
47 api_key: impl Into<String>,
48 index_name: impl Into<String>,
49 environment: impl Into<String>,
50 ) -> Self {
51 Self {
52 api_key: api_key.into(),
53 index_name: index_name.into(),
54 environment: environment.into(),
55 namespace: None,
56 #[cfg(feature = "http-vectorstore")]
57 client: None,
58 entries: RwLock::new(HashMap::new()),
59 }
60 }
61
62 pub fn with_namespace(mut self, ns: impl Into<String>) -> Self {
64 self.namespace = Some(ns.into());
65 self
66 }
67
68 pub fn index_name(&self) -> &str {
70 &self.index_name
71 }
72
73 pub fn environment(&self) -> &str {
75 &self.environment
76 }
77
78 pub fn namespace(&self) -> Option<&str> {
80 self.namespace.as_deref()
81 }
82
83 #[cfg(feature = "http-vectorstore")]
85 pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
86 self.client = Some(client);
87 self
88 }
89
90 #[cfg(feature = "http-vectorstore")]
92 #[allow(dead_code)]
93 fn upsert_url(&self) -> String {
94 format!(
95 "https://{}-{}.svc.{}.pinecone.io/vectors/upsert",
96 self.index_name, "xxxxx", self.environment
97 )
98 }
99}
100
101#[async_trait]
102impl VectorStore for PineconeStore {
103 async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
104 let mut entries = self.entries.write().await;
105 entries.insert(entry.id, entry);
106 Ok(())
107 }
108
109 async fn search(
110 &self,
111 query_embedding: &[f32],
112 top_k: usize,
113 session_filter: Option<Uuid>,
114 ) -> ArgentorResult<Vec<SearchResult>> {
115 if query_embedding.is_empty() {
116 return Err(ArgentorError::Agent("Empty query embedding".to_string()));
117 }
118 let entries = self.entries.read().await;
119 let mut scored: Vec<SearchResult> = entries
120 .values()
121 .filter(|e| {
122 if let Some(sid) = session_filter {
123 e.session_id == Some(sid)
124 } else {
125 true
126 }
127 })
128 .map(|e| {
129 let score = cosine(query_embedding, &e.embedding);
130 SearchResult {
131 entry: e.clone(),
132 score,
133 }
134 })
135 .collect();
136 scored.sort_by(|a, b| {
137 b.score
138 .partial_cmp(&a.score)
139 .unwrap_or(std::cmp::Ordering::Equal)
140 });
141 scored.truncate(top_k);
142 Ok(scored)
143 }
144
145 async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
146 let mut entries = self.entries.write().await;
147 Ok(entries.remove(&id).is_some())
148 }
149
150 async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
151 let entries = self.entries.read().await;
152 Ok(entries
153 .values()
154 .filter(|e| {
155 if let Some(sid) = session_filter {
156 e.session_id == Some(sid)
157 } else {
158 true
159 }
160 })
161 .cloned()
162 .collect())
163 }
164
165 async fn count(&self) -> ArgentorResult<usize> {
166 let entries = self.entries.read().await;
167 Ok(entries.len())
168 }
169}
170
/// Cosine similarity between two embeddings, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` ("no similarity") when the slices differ in length, when
/// either vector has zero magnitude (including empty slices), or when the
/// inputs contain non-finite values (NaN / ±inf) that would otherwise yield a
/// meaningless NaN score.
///
/// Sums are accumulated in `f64`: squaring large-but-finite `f32` components
/// (e.g. `1e20`) overflows `f32` to infinity and would turn a valid
/// similarity into NaN.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    // Single pass: (dot product, |a|^2, |b|^2).
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    let sim = dot / denom;
    if sim.is_finite() {
        // Clamp guards against rounding drifting just outside [-1, 1].
        sim.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
184
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds a fresh `MemoryEntry` with a random id and empty metadata.
    fn entry(content: &str, emb: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_string(),
            embedding: emb,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    #[test]
    fn test_new_sets_fields() {
        let store = PineconeStore::new("key-123", "my-index", "us-east-1-aws");
        assert_eq!(store.api_key, "key-123");
        assert_eq!(store.index_name(), "my-index");
        assert_eq!(store.environment(), "us-east-1-aws");
        assert!(store.namespace().is_none());
    }

    #[test]
    fn test_with_namespace() {
        let store = PineconeStore::new("k", "i", "e").with_namespace("tenant-a");
        assert_eq!(store.namespace(), Some("tenant-a"));
    }

    #[test]
    fn test_accepts_owned_strings() {
        let store = PineconeStore::new(
            String::from("k"),
            String::from("i"),
            String::from("us-west-2-aws"),
        );
        assert_eq!(store.environment(), "us-west-2-aws");
    }

    #[tokio::test]
    async fn test_insert_increments_count() {
        let store = PineconeStore::new("k", "i", "e");
        assert_eq!(store.count().await.unwrap(), 0);
        store
            .insert(entry("hello", vec![1.0, 0.0, 0.0], None))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_insert_many() {
        let store = PineconeStore::new("k", "i", "e");
        for i in 0..25 {
            store
                .insert(entry(&format!("e{i}"), vec![i as f32, 0.0], None))
                .await
                .unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 25);
    }

    #[tokio::test]
    async fn test_insert_same_id_overwrites() {
        let store = PineconeStore::new("k", "i", "e");
        let first = entry("old", vec![1.0], None);
        let mut second = first.clone();
        second.content = "new".to_string();
        store.insert(first).await.unwrap();
        store.insert(second).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
        assert_eq!(store.list(None).await.unwrap()[0].content, "new");
    }

    #[tokio::test]
    async fn test_search_orders_by_similarity() {
        let store = PineconeStore::new("k", "i", "e");
        store
            .insert(entry("close", vec![0.9, 0.1, 0.0], None))
            .await
            .unwrap();
        store
            .insert(entry("far", vec![0.0, 0.0, 1.0], None))
            .await
            .unwrap();
        let results = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.content, "close");
        assert!(results[0].score > results[1].score);
    }

    #[tokio::test]
    async fn test_search_respects_top_k() {
        let store = PineconeStore::new("k", "i", "e");
        for i in 0..10 {
            store
                .insert(entry(&format!("e{i}"), vec![1.0, i as f32 / 10.0], None))
                .await
                .unwrap();
        }
        let results = store.search(&[1.0, 0.0], 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_search_top_k_zero_returns_empty() {
        let store = PineconeStore::new("k", "i", "e");
        store
            .insert(entry("a", vec![1.0, 0.0], None))
            .await
            .unwrap();
        let results = store.search(&[1.0, 0.0], 0, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_search_empty_embedding_errors() {
        let store = PineconeStore::new("k", "i", "e");
        assert!(store.search(&[], 5, None).await.is_err());
    }

    #[tokio::test]
    async fn test_search_dimension_mismatch_scores_zero() {
        let store = PineconeStore::new("k", "i", "e");
        store
            .insert(entry("3d", vec![1.0, 0.0, 0.0], None))
            .await
            .unwrap();
        let results = store.search(&[1.0, 0.0], 5, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].score.abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let store = PineconeStore::new("k", "i", "e");
        let sid = Uuid::new_v4();
        store
            .insert(entry("a", vec![1.0, 0.0], Some(sid)))
            .await
            .unwrap();
        store
            .insert(entry("b", vec![1.0, 0.0], None))
            .await
            .unwrap();
        let results = store.search(&[1.0, 0.0], 10, Some(sid)).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.content, "a");
    }

    #[tokio::test]
    async fn test_search_session_filter_no_match() {
        let store = PineconeStore::new("k", "i", "e");
        store
            .insert(entry("a", vec![1.0, 0.0], Some(Uuid::new_v4())))
            .await
            .unwrap();
        let results = store
            .search(&[1.0, 0.0], 10, Some(Uuid::new_v4()))
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_delete_existing() {
        let store = PineconeStore::new("k", "i", "e");
        let e = entry("to-delete", vec![1.0], None);
        let id = e.id;
        store.insert(e).await.unwrap();
        assert!(store.delete(id).await.unwrap());
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_delete_missing_returns_false() {
        let store = PineconeStore::new("k", "i", "e");
        assert!(!store.delete(Uuid::new_v4()).await.unwrap());
    }

    #[tokio::test]
    async fn test_list_all() {
        let store = PineconeStore::new("k", "i", "e");
        store.insert(entry("a", vec![1.0], None)).await.unwrap();
        store.insert(entry("b", vec![0.5], None)).await.unwrap();
        let all = store.list(None).await.unwrap();
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_list_filtered_by_session() {
        let store = PineconeStore::new("k", "i", "e");
        let sid = Uuid::new_v4();
        store
            .insert(entry("a", vec![1.0], Some(sid)))
            .await
            .unwrap();
        store.insert(entry("b", vec![0.5], None)).await.unwrap();
        let filtered = store.list(Some(sid)).await.unwrap();
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].content, "a");
    }

    #[tokio::test]
    async fn test_namespace_isolation_does_not_cross_instances() {
        let a = PineconeStore::new("k", "i", "e").with_namespace("ns-a");
        let b = PineconeStore::new("k", "i", "e").with_namespace("ns-b");
        a.insert(entry("x", vec![1.0], None)).await.unwrap();
        assert_eq!(a.count().await.unwrap(), 1);
        assert_eq!(b.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_insert_preserves_metadata() {
        let store = PineconeStore::new("k", "i", "e");
        let mut e = entry("m", vec![1.0, 0.0], None);
        e.metadata
            .insert("tag".to_string(), serde_json::json!("important"));
        let id = e.id;
        store.insert(e).await.unwrap();
        let all = store.list(None).await.unwrap();
        let got = all.iter().find(|x| x.id == id).unwrap();
        assert_eq!(
            got.metadata.get("tag").unwrap(),
            &serde_json::json!("important")
        );
    }

    #[tokio::test]
    async fn test_search_returns_empty_when_store_empty() {
        let store = PineconeStore::new("k", "i", "e");
        let results = store.search(&[1.0, 0.0], 5, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_count_after_deletes() {
        let store = PineconeStore::new("k", "i", "e");
        let e1 = entry("a", vec![1.0], None);
        let e2 = entry("b", vec![0.5], None);
        let id1 = e1.id;
        store.insert(e1).await.unwrap();
        store.insert(e2).await.unwrap();
        store.delete(id1).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[test]
    fn test_cosine_boundaries() {
        assert!((cosine(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!((cosine(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
        assert!(cosine(&[1.0, 0.0], &[0.0, 1.0]).abs() < 1e-6);
        // Mismatched lengths and zero vectors score 0.
        assert!(cosine(&[1.0], &[1.0, 0.0]).abs() < 1e-6);
        assert!(cosine(&[0.0, 0.0], &[1.0, 1.0]).abs() < 1e-6);
    }
}
385}