Skip to main content

argentor_memory/
pinecone.rs

//! Pinecone vector store adapter.
//!
//! Pinecone is a managed vector database. This adapter implements the
//! [`VectorStore`] trait with a stub backend that stores entries in-memory
//! using brute-force cosine similarity, suitable for testing and local
//! development without external dependencies.
//!
//! Real HTTP calls to the Pinecone REST API are gated behind the
//! `http-vectorstore` feature flag and are not wired by default.
10
11use crate::store::{MemoryEntry, SearchResult, VectorStore};
12use argentor_core::{ArgentorError, ArgentorResult};
13use async_trait::async_trait;
14use std::collections::HashMap;
15use tokio::sync::RwLock;
16use uuid::Uuid;
17
18/// Pinecone vector store adapter.
19///
20/// In stub mode (default), entries are stored in-memory and searched with
21/// brute-force cosine similarity. In HTTP mode (feature `http-vectorstore`),
22/// a [`reqwest::Client`] is constructed to talk to the Pinecone API.
23pub struct PineconeStore {
24    /// Pinecone API key.
25    #[allow(dead_code)]
26    api_key: String,
27    /// Pinecone index name.
28    #[allow(dead_code)]
29    index_name: String,
30    /// Pinecone environment (e.g., "us-east-1-aws").
31    #[allow(dead_code)]
32    environment: String,
33    /// Optional namespace for multi-tenant isolation.
34    #[allow(dead_code)]
35    namespace: Option<String>,
36    /// HTTP client — `None` in stub mode.
37    #[cfg(feature = "http-vectorstore")]
38    #[allow(dead_code)]
39    client: Option<reqwest::Client>,
40    /// Stub in-memory storage: id -> entry.
41    entries: RwLock<HashMap<Uuid, MemoryEntry>>,
42}
43
44impl PineconeStore {
45    /// Create a new Pinecone adapter in stub mode.
46    pub fn new(
47        api_key: impl Into<String>,
48        index_name: impl Into<String>,
49        environment: impl Into<String>,
50    ) -> Self {
51        Self {
52            api_key: api_key.into(),
53            index_name: index_name.into(),
54            environment: environment.into(),
55            namespace: None,
56            #[cfg(feature = "http-vectorstore")]
57            client: None,
58            entries: RwLock::new(HashMap::new()),
59        }
60    }
61
62    /// Set the namespace for multi-tenant isolation.
63    pub fn with_namespace(mut self, ns: impl Into<String>) -> Self {
64        self.namespace = Some(ns.into());
65        self
66    }
67
68    /// Return the configured index name.
69    pub fn index_name(&self) -> &str {
70        &self.index_name
71    }
72
73    /// Return the configured environment.
74    pub fn environment(&self) -> &str {
75        &self.environment
76    }
77
78    /// Return the configured namespace (if any).
79    pub fn namespace(&self) -> Option<&str> {
80        self.namespace.as_deref()
81    }
82
83    /// Enable real HTTP mode with a configured [`reqwest::Client`].
84    #[cfg(feature = "http-vectorstore")]
85    pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
86        self.client = Some(client);
87        self
88    }
89
90    /// Build the Pinecone upsert endpoint URL.
91    #[cfg(feature = "http-vectorstore")]
92    #[allow(dead_code)]
93    fn upsert_url(&self) -> String {
94        format!(
95            "https://{}-{}.svc.{}.pinecone.io/vectors/upsert",
96            self.index_name, "xxxxx", self.environment
97        )
98    }
99}
100
101#[async_trait]
102impl VectorStore for PineconeStore {
103    async fn insert(&self, entry: MemoryEntry) -> ArgentorResult<()> {
104        let mut entries = self.entries.write().await;
105        entries.insert(entry.id, entry);
106        Ok(())
107    }
108
109    async fn search(
110        &self,
111        query_embedding: &[f32],
112        top_k: usize,
113        session_filter: Option<Uuid>,
114    ) -> ArgentorResult<Vec<SearchResult>> {
115        if query_embedding.is_empty() {
116            return Err(ArgentorError::Agent("Empty query embedding".to_string()));
117        }
118        let entries = self.entries.read().await;
119        let mut scored: Vec<SearchResult> = entries
120            .values()
121            .filter(|e| {
122                if let Some(sid) = session_filter {
123                    e.session_id == Some(sid)
124                } else {
125                    true
126                }
127            })
128            .map(|e| {
129                let score = cosine(query_embedding, &e.embedding);
130                SearchResult {
131                    entry: e.clone(),
132                    score,
133                }
134            })
135            .collect();
136        scored.sort_by(|a, b| {
137            b.score
138                .partial_cmp(&a.score)
139                .unwrap_or(std::cmp::Ordering::Equal)
140        });
141        scored.truncate(top_k);
142        Ok(scored)
143    }
144
145    async fn delete(&self, id: Uuid) -> ArgentorResult<bool> {
146        let mut entries = self.entries.write().await;
147        Ok(entries.remove(&id).is_some())
148    }
149
150    async fn list(&self, session_filter: Option<Uuid>) -> ArgentorResult<Vec<MemoryEntry>> {
151        let entries = self.entries.read().await;
152        Ok(entries
153            .values()
154            .filter(|e| {
155                if let Some(sid) = session_filter {
156                    e.session_id == Some(sid)
157                } else {
158                    true
159                }
160            })
161            .cloned()
162            .collect())
163    }
164
165    async fn count(&self) -> ArgentorResult<usize> {
166        let entries = self.entries.read().await;
167        Ok(entries.len())
168    }
169}
170
/// Cosine similarity between two embedding vectors.
///
/// Returns `0.0` when the slices differ in length or when either vector has
/// zero magnitude (which includes two empty slices). NaN or infinite
/// components can still yield a NaN result.
///
/// The dot product and squared norms are accumulated in `f64`: summing `f32`
/// squares overflows to infinity for large-magnitude components (turning the
/// result into NaN) and underflows to zero for very small ones.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    // Take the square roots separately so the denominator cannot overflow.
    // For finite inputs the quotient lies in [-1, 1] up to rounding, so
    // narrowing back to f32 is safe.
    (dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt())) as f32
}
184
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Fresh stub store with throwaway credentials.
    fn stub() -> PineconeStore {
        PineconeStore::new("k", "i", "e")
    }

    /// Build an entry with a random id, empty metadata and the current time.
    fn entry(content: &str, emb: Vec<f32>, session: Option<Uuid>) -> MemoryEntry {
        MemoryEntry {
            id: Uuid::new_v4(),
            content: content.to_string(),
            embedding: emb,
            metadata: HashMap::new(),
            session_id: session,
            created_at: Utc::now(),
        }
    }

    // --- construction and accessors ---

    #[test]
    fn test_new_sets_fields() {
        let store = PineconeStore::new("key-123", "my-index", "us-east-1-aws");
        assert_eq!(store.api_key, "key-123");
        assert_eq!(store.index_name(), "my-index");
        assert_eq!(store.environment(), "us-east-1-aws");
        assert!(store.namespace().is_none());
    }

    #[test]
    fn test_with_namespace() {
        let store = stub().with_namespace("tenant-a");
        assert_eq!(store.namespace(), Some("tenant-a"));
    }

    #[test]
    fn test_accepts_owned_strings() {
        let (key, index, env) = (
            String::from("k"),
            String::from("i"),
            String::from("us-west-2-aws"),
        );
        let store = PineconeStore::new(key, index, env);
        assert_eq!(store.environment(), "us-west-2-aws");
    }

    // --- insert / count ---

    #[tokio::test]
    async fn test_insert_increments_count() {
        let store = stub();
        assert_eq!(store.count().await.unwrap(), 0);

        let e = entry("hello", vec![1.0, 0.0, 0.0], None);
        store.insert(e).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_insert_many() {
        let store = stub();
        for i in 0..25 {
            let e = entry(&format!("e{i}"), vec![i as f32, 0.0], None);
            store.insert(e).await.unwrap();
        }
        assert_eq!(store.count().await.unwrap(), 25);
    }

    // --- search ---

    #[tokio::test]
    async fn test_search_orders_by_similarity() {
        let store = stub();
        let close = entry("close", vec![0.9, 0.1, 0.0], None);
        let far = entry("far", vec![0.0, 0.0, 1.0], None);
        store.insert(close).await.unwrap();
        store.insert(far).await.unwrap();

        let results = store.search(&[1.0, 0.0, 0.0], 2, None).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.content, "close");
        assert!(results[0].score > results[1].score);
    }

    #[tokio::test]
    async fn test_search_respects_top_k() {
        let store = stub();
        for i in 0..10 {
            let e = entry(&format!("e{i}"), vec![1.0, i as f32 / 10.0], None);
            store.insert(e).await.unwrap();
        }

        let results = store.search(&[1.0, 0.0], 3, None).await.unwrap();

        assert_eq!(results.len(), 3);
    }

    #[tokio::test]
    async fn test_search_empty_embedding_errors() {
        let outcome = stub().search(&[], 5, None).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_search_session_filter() {
        let store = stub();
        let sid = Uuid::new_v4();
        let in_session = entry("a", vec![1.0, 0.0], Some(sid));
        let no_session = entry("b", vec![1.0, 0.0], None);
        store.insert(in_session).await.unwrap();
        store.insert(no_session).await.unwrap();

        let results = store.search(&[1.0, 0.0], 10, Some(sid)).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.content, "a");
    }

    // --- delete ---

    #[tokio::test]
    async fn test_delete_existing() {
        let store = stub();
        let doomed = entry("to-delete", vec![1.0], None);
        let id = doomed.id;
        store.insert(doomed).await.unwrap();

        assert!(store.delete(id).await.unwrap());
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_delete_missing_returns_false() {
        let removed = stub().delete(Uuid::new_v4()).await.unwrap();
        assert!(!removed);
    }

    // --- list ---

    #[tokio::test]
    async fn test_list_all() {
        let store = stub();
        store.insert(entry("a", vec![1.0], None)).await.unwrap();
        store.insert(entry("b", vec![0.5], None)).await.unwrap();

        let all = store.list(None).await.unwrap();

        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_list_filtered_by_session() {
        let store = stub();
        let sid = Uuid::new_v4();
        store.insert(entry("a", vec![1.0], Some(sid))).await.unwrap();
        store.insert(entry("b", vec![0.5], None)).await.unwrap();

        let filtered = store.list(Some(sid)).await.unwrap();

        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].content, "a");
    }

    // --- isolation, metadata, edge cases ---

    #[tokio::test]
    async fn test_namespace_isolation_does_not_cross_instances() {
        let a = stub().with_namespace("ns-a");
        let b = stub().with_namespace("ns-b");

        a.insert(entry("x", vec![1.0], None)).await.unwrap();

        assert_eq!(a.count().await.unwrap(), 1);
        assert_eq!(b.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_insert_preserves_metadata() {
        let store = stub();
        let mut tagged = entry("m", vec![1.0, 0.0], None);
        tagged
            .metadata
            .insert("tag".to_string(), serde_json::json!("important"));
        let id = tagged.id;
        store.insert(tagged).await.unwrap();

        let all = store.list(None).await.unwrap();
        let got = all.iter().find(|x| x.id == id).unwrap();

        assert_eq!(got.metadata.get("tag").unwrap(), &serde_json::json!("important"));
    }

    #[tokio::test]
    async fn test_search_returns_empty_when_store_empty() {
        let results = stub().search(&[1.0, 0.0], 5, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_count_after_deletes() {
        let store = stub();
        let first = entry("a", vec![1.0], None);
        let second = entry("b", vec![0.5], None);
        let first_id = first.id;
        store.insert(first).await.unwrap();
        store.insert(second).await.unwrap();

        store.delete(first_id).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
    }
}