Skip to main content

agentrs_memory/
vector.rs

1use std::{cmp::Ordering, sync::Arc};
2
3use async_trait::async_trait;
4use tokio::sync::RwLock;
5
6use agentrs_core::{Memory, Message, Result};
7
8use crate::{InMemoryMemory, SearchableMemory};
9
10/// Search result returned by a vector store.
11#[derive(Debug, Clone)]
12pub struct VectorSearchResult {
13    /// Similarity score.
14    pub score: f32,
15    /// Stored payload.
16    pub payload: Message,
17}
18
19/// Computes embeddings for messages.
20#[async_trait]
21pub trait Embedder: Send + Sync + 'static {
22    /// Generates an embedding vector.
23    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
24}
25
26/// Persists and searches embedding vectors.
27#[async_trait]
28pub trait VectorStore: Send + Sync + 'static {
29    /// Upserts a vector and payload.
30    async fn upsert(&self, id: String, vector: Vec<f32>, payload: Message) -> Result<()>;
31
32    /// Searches the store.
33    async fn search(&self, query: Vec<f32>, limit: usize) -> Result<Vec<VectorSearchResult>>;
34
35    /// Clears all stored vectors.
36    async fn clear(&self) -> Result<()>;
37}
38
/// Tiny, deterministic [`Embedder`] intended for tests and local demos.
///
/// It is a stateless unit struct, so it is free to copy and to construct via
/// `Default`.
#[derive(Clone, Copy, Debug, Default)]
pub struct SimpleEmbedder;
42
43#[async_trait]
44impl Embedder for SimpleEmbedder {
45    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
46        let mut buckets = vec![0.0_f32; 16];
47        for (index, byte) in text.bytes().enumerate() {
48            buckets[index % 16] += f32::from(byte) / 255.0;
49        }
50        Ok(buckets)
51    }
52}
53
54/// In-memory vector store with cosine similarity search.
55#[derive(Default)]
56pub struct InMemoryVectorStore {
57    items: RwLock<Vec<(String, Vec<f32>, Message)>>,
58}
59
60impl InMemoryVectorStore {
61    /// Creates an empty store.
62    pub fn new() -> Self {
63        Self::default()
64    }
65}
66
67#[async_trait]
68impl VectorStore for InMemoryVectorStore {
69    async fn upsert(&self, id: String, vector: Vec<f32>, payload: Message) -> Result<()> {
70        let mut items = self.items.write().await;
71        if let Some(existing) = items.iter_mut().find(|item| item.0 == id) {
72            *existing = (id, vector, payload);
73        } else {
74            items.push((id, vector, payload));
75        }
76        Ok(())
77    }
78
79    async fn search(&self, query: Vec<f32>, limit: usize) -> Result<Vec<VectorSearchResult>> {
80        let items = self.items.read().await;
81        let mut scored = items
82            .iter()
83            .map(|(_, vector, payload)| VectorSearchResult {
84                score: cosine_similarity(vector, &query),
85                payload: payload.clone(),
86            })
87            .collect::<Vec<_>>();
88
89        scored.sort_by(|left, right| {
90            right
91                .score
92                .partial_cmp(&left.score)
93                .unwrap_or(Ordering::Equal)
94        });
95        scored.truncate(limit);
96        Ok(scored)
97    }
98
99    async fn clear(&self) -> Result<()> {
100        self.items.write().await.clear();
101        Ok(())
102    }
103}
104
105/// Memory backend that combines recent history with semantic retrieval.
106pub struct VectorMemory<E = SimpleEmbedder, S = InMemoryVectorStore> {
107    embedder: Arc<E>,
108    store: Arc<S>,
109    recent: InMemoryMemory,
110}
111
112impl VectorMemory<SimpleEmbedder, InMemoryVectorStore> {
113    /// Creates a vector memory with built-in components.
114    pub fn new() -> Self {
115        Self {
116            embedder: Arc::new(SimpleEmbedder),
117            store: Arc::new(InMemoryVectorStore::new()),
118            recent: InMemoryMemory::new(),
119        }
120    }
121}
122
123impl<E, S> VectorMemory<E, S>
124where
125    E: Embedder,
126    S: VectorStore,
127{
128    /// Creates a vector memory with custom embedder and store.
129    pub fn with_components(embedder: Arc<E>, store: Arc<S>) -> Self {
130        Self {
131            embedder,
132            store,
133            recent: InMemoryMemory::new(),
134        }
135    }
136}
137
138#[async_trait]
139impl<E, S> Memory for VectorMemory<E, S>
140where
141    E: Embedder,
142    S: VectorStore,
143{
144    async fn store(&mut self, key: &str, value: Message) -> Result<()> {
145        let vector = self.embedder.embed(&value.text_content()).await?;
146        self.store
147            .upsert(
148                format!("{key}-{}", uuid::Uuid::new_v4()),
149                vector,
150                value.clone(),
151            )
152            .await?;
153        self.recent.store(key, value).await
154    }
155
156    async fn retrieve(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
157        let vector = self.embedder.embed(query).await?;
158        Ok(self
159            .store
160            .search(vector, limit)
161            .await?
162            .into_iter()
163            .map(|result| result.payload)
164            .collect())
165    }
166
167    async fn history(&self) -> Result<Vec<Message>> {
168        self.recent.history().await
169    }
170
171    async fn clear(&mut self) -> Result<()> {
172        self.store.clear().await?;
173        self.recent.clear().await
174    }
175}
176
177#[async_trait]
178impl<E, S> SearchableMemory for VectorMemory<E, S>
179where
180    E: Embedder,
181    S: VectorStore,
182{
183    async fn token_count(&self) -> Result<usize> {
184        Ok(self
185            .recent
186            .history()
187            .await?
188            .into_iter()
189            .map(|message| message.text_content().chars().count() / 4)
190            .sum())
191    }
192
193    async fn search(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
194        self.retrieve(query, limit).await
195    }
196}
197
/// Computes the cosine similarity of two vectors.
///
/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` when the vectors differ in
/// length, are empty, either one has zero magnitude, or the result would not
/// be a finite number (for example when an input contains NaN or infinity).
fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    if left.len() != right.len() || left.is_empty() {
        return 0.0;
    }

    // Accumulate in f64: squaring a large finite f32 overflows f32 to
    // infinity, which would turn the final division into NaN.
    let (dot, left_sq, right_sq) = left.iter().zip(right).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, left_sq, right_sq), (&l, &r)| {
            let (l, r) = (f64::from(l), f64::from(r));
            (dot + l * r, left_sq + l * l, right_sq + r * r)
        },
    );

    let denominator = left_sq.sqrt() * right_sq.sqrt();
    if denominator == 0.0 {
        return 0.0;
    }

    let similarity = dot / denominator;
    if similarity.is_finite() {
        // Rounding can push the ratio marginally outside [-1, 1]; after the
        // clamp the narrowing cast to f32 cannot overflow.
        similarity.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}