1use std::{cmp::Ordering, sync::Arc};
2
3use async_trait::async_trait;
4use tokio::sync::RwLock;
5
6use agentrs_core::{Memory, Message, Result};
7
8use crate::{InMemoryMemory, SearchableMemory};
9
#[derive(Debug, Clone)]
/// A single hit returned by [`VectorStore::search`].
pub struct VectorSearchResult {
    /// Similarity between the query and the stored vector. For `InMemoryVectorStore` this
    /// is the cosine similarity, and results are returned highest score first.
    pub score: f32,
    /// The message that was passed as `payload` to [`VectorStore::upsert`] for this entry.
    pub payload: Message,
}
18
#[async_trait]
/// Turns text into a dense `f32` embedding vector.
///
/// `InMemoryVectorStore` scores vectors of different lengths as `0.0`, so an embedder
/// should produce the same dimensionality for every input.
pub trait Embedder: Send + Sync + 'static {
    /// Embeds `text` into a vector.
    ///
    /// # Errors
    /// Implementation-defined; returned when no embedding could be produced.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
}
25
#[async_trait]
/// Storage backend that indexes embedding vectors together with the messages they encode.
pub trait VectorStore: Send + Sync + 'static {
    /// Stores `vector` and `payload` under `id`.
    ///
    /// `InMemoryVectorStore` replaces an existing entry with the same `id` and otherwise
    /// appends a new one.
    async fn upsert(&self, id: String, vector: Vec<f32>, payload: Message) -> Result<()>;

    /// Returns at most `limit` entries ranked against `query`.
    ///
    /// `InMemoryVectorStore` ranks by cosine similarity, highest score first.
    async fn search(&self, query: Vec<f32>, limit: usize) -> Result<Vec<VectorSearchResult>>;

    /// Expected to remove every stored entry (as `InMemoryVectorStore` does).
    async fn clear(&self) -> Result<()>;
}
38
39#[derive(Debug, Clone, Copy, Default)]
41pub struct SimpleEmbedder;
42
43#[async_trait]
44impl Embedder for SimpleEmbedder {
45 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
46 let mut buckets = vec![0.0_f32; 16];
47 for (index, byte) in text.bytes().enumerate() {
48 buckets[index % 16] += f32::from(byte) / 255.0;
49 }
50 Ok(buckets)
51 }
52}
53
54#[derive(Default)]
56pub struct InMemoryVectorStore {
57 items: RwLock<Vec<(String, Vec<f32>, Message)>>,
58}
59
60impl InMemoryVectorStore {
61 pub fn new() -> Self {
63 Self::default()
64 }
65}
66
67#[async_trait]
68impl VectorStore for InMemoryVectorStore {
69 async fn upsert(&self, id: String, vector: Vec<f32>, payload: Message) -> Result<()> {
70 let mut items = self.items.write().await;
71 if let Some(existing) = items.iter_mut().find(|item| item.0 == id) {
72 *existing = (id, vector, payload);
73 } else {
74 items.push((id, vector, payload));
75 }
76 Ok(())
77 }
78
79 async fn search(&self, query: Vec<f32>, limit: usize) -> Result<Vec<VectorSearchResult>> {
80 let items = self.items.read().await;
81 let mut scored = items
82 .iter()
83 .map(|(_, vector, payload)| VectorSearchResult {
84 score: cosine_similarity(vector, &query),
85 payload: payload.clone(),
86 })
87 .collect::<Vec<_>>();
88
89 scored.sort_by(|left, right| {
90 right
91 .score
92 .partial_cmp(&left.score)
93 .unwrap_or(Ordering::Equal)
94 });
95 scored.truncate(limit);
96 Ok(scored)
97 }
98
99 async fn clear(&self) -> Result<()> {
100 self.items.write().await.clear();
101 Ok(())
102 }
103}
104
/// A [`Memory`] that embeds every stored message into a vector store for similarity
/// retrieval, and also records it in an [`InMemoryMemory`] that serves `history`.
///
/// The type parameters default to [`SimpleEmbedder`] and [`InMemoryVectorStore`].
pub struct VectorMemory<E = SimpleEmbedder, S = InMemoryVectorStore> {
    /// Converts stored message text and retrieval queries into vectors.
    embedder: Arc<E>,
    /// Holds the embedded messages; searched by `retrieve`.
    store: Arc<S>,
    /// Backs `history` and `token_count`.
    recent: InMemoryMemory,
}
111
112impl VectorMemory<SimpleEmbedder, InMemoryVectorStore> {
113 pub fn new() -> Self {
115 Self {
116 embedder: Arc::new(SimpleEmbedder),
117 store: Arc::new(InMemoryVectorStore::new()),
118 recent: InMemoryMemory::new(),
119 }
120 }
121}
122
123impl<E, S> VectorMemory<E, S>
124where
125 E: Embedder,
126 S: VectorStore,
127{
128 pub fn with_components(embedder: Arc<E>, store: Arc<S>) -> Self {
130 Self {
131 embedder,
132 store,
133 recent: InMemoryMemory::new(),
134 }
135 }
136}
137
138#[async_trait]
139impl<E, S> Memory for VectorMemory<E, S>
140where
141 E: Embedder,
142 S: VectorStore,
143{
144 async fn store(&mut self, key: &str, value: Message) -> Result<()> {
145 let vector = self.embedder.embed(&value.text_content()).await?;
146 self.store
147 .upsert(
148 format!("{key}-{}", uuid::Uuid::new_v4()),
149 vector,
150 value.clone(),
151 )
152 .await?;
153 self.recent.store(key, value).await
154 }
155
156 async fn retrieve(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
157 let vector = self.embedder.embed(query).await?;
158 Ok(self
159 .store
160 .search(vector, limit)
161 .await?
162 .into_iter()
163 .map(|result| result.payload)
164 .collect())
165 }
166
167 async fn history(&self) -> Result<Vec<Message>> {
168 self.recent.history().await
169 }
170
171 async fn clear(&mut self) -> Result<()> {
172 self.store.clear().await?;
173 self.recent.clear().await
174 }
175}
176
177#[async_trait]
178impl<E, S> SearchableMemory for VectorMemory<E, S>
179where
180 E: Embedder,
181 S: VectorStore,
182{
183 async fn token_count(&self) -> Result<usize> {
184 Ok(self
185 .recent
186 .history()
187 .await?
188 .into_iter()
189 .map(|message| message.text_content().chars().count() / 4)
190 .sum())
191 }
192
193 async fn search(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
194 self.retrieve(query, limit).await
195 }
196}
197
/// Cosine similarity of two equal-length vectors, in `[-1.0, 1.0]`.
///
/// Degenerate inputs score `0.0`: mismatched lengths, empty vectors, a zero-norm vector,
/// or any non-finite intermediate result (e.g. NaN or infinite components).
fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    if left.len() != right.len() || left.is_empty() {
        return 0.0;
    }

    // Accumulate in f64 in a single pass: squaring large f32 components (e.g. 1e30) would
    // overflow f32 to infinity and turn a perfect match into NaN; the wider type also
    // reduces rounding error.
    let (dot, left_sq, right_sq) = left.iter().zip(right).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, left_sq, right_sq), (&l, &r)| {
            let (l, r) = (f64::from(l), f64::from(r));
            (dot + l * r, left_sq + l * l, right_sq + r * r)
        },
    );

    // Multiply the square roots (not sqrt of the product) to keep the intermediate small.
    let denominator = left_sq.sqrt() * right_sq.sqrt();
    if denominator == 0.0 {
        return 0.0;
    }

    let similarity = dot / denominator;
    if similarity.is_finite() {
        // Rounding can push |similarity| marginally past 1; clamp to the true range.
        similarity.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}