autoagents_core/vector_store/
mod.rs1pub use request::VectorSearchRequest;
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use uuid::Uuid;
7
8use crate::document::Document;
9use crate::embeddings::{Embed, Embedding, EmbeddingError, SharedEmbeddingProvider, TextEmbedder};
10use crate::one_or_many::OneOrMany;
11use crate::vector_store::request::{FilterError, SearchFilter};
12
13pub mod in_memory_store;
14pub mod request;
15
/// Reserved vector name `"default"`.
///
/// NOTE(review): not referenced in this module; presumably used by store implementations
/// for documents that carry a single, unnamed vector — confirm against `in_memory_store`.
pub const DEFAULT_VECTOR_NAME: &str = "default";
17
18#[derive(Debug, thiserror::Error)]
19pub enum VectorStoreError {
20 #[error("Embedding error: {0}")]
21 EmbeddingError(#[from] EmbeddingError),
22
23 #[error("Json error: {0}")]
24 JsonError(#[from] serde_json::Error),
25
26 #[error("Filter error: {0}")]
27 FilterError(#[from] FilterError),
28
29 #[error("Datastore error: {0}")]
30 DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
31
32 #[error("Error while building VectorSearchRequest: {0}")]
33 BuilderError(String),
34}
35
36#[async_trait]
37pub trait VectorStoreIndex: Send + Sync {
38 type Filter: SearchFilter + Send + Sync;
39
40 async fn insert_documents<T>(&self, documents: Vec<T>) -> Result<(), VectorStoreError>
41 where
42 T: Embed + Serialize + Send + Sync + Clone;
43
44 async fn insert_documents_with_ids<T>(
45 &self,
46 documents: Vec<(String, T)>,
47 ) -> Result<(), VectorStoreError>
48 where
49 T: Embed + Serialize + Send + Sync + Clone;
50
51 async fn top_n<T>(
52 &self,
53 req: VectorSearchRequest<Self::Filter>,
54 ) -> Result<Vec<(f64, String, T)>, VectorStoreError>
55 where
56 T: for<'de> Deserialize<'de> + Send + Sync;
57
58 async fn top_n_ids(
59 &self,
60 req: VectorSearchRequest<Self::Filter>,
61 ) -> Result<Vec<(f64, String)>, VectorStoreError>;
62
63 async fn insert_documents_with_named_vectors<T>(
64 &self,
65 documents: Vec<NamedVectorDocument<T>>,
66 ) -> Result<(), VectorStoreError>
67 where
68 T: Serialize + Send + Sync + Clone;
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct VectorStoreOutput {
73 pub score: f64,
74 pub id: String,
75 pub document: Document,
76}
77
78#[derive(Debug, Clone)]
79pub struct PreparedDocument {
80 pub id: String,
81 pub raw: serde_json::Value,
82 pub embeddings: OneOrMany<Embedding>,
83}
84
/// Input for named-vector insertion: one payload plus several texts, each to be embedded
/// under its own vector name.
#[derive(Debug)]
#[derive(Clone)]
pub struct NamedVectorDocument<T> {
    /// Identifier of the document.
    pub id: String,
    /// Payload stored with the vectors; `embed_named_documents` serializes it to JSON.
    pub raw: T,
    /// Vector name -> text to embed under that name.
    pub vectors: HashMap<String, String>,
}
91
92#[derive(Debug, Clone)]
93pub struct PreparedNamedVectorDocument {
94 pub id: String,
95 pub raw: serde_json::Value,
96 pub vectors: HashMap<String, Vec<f32>>,
97}
98
99pub async fn embed_documents<T>(
100 provider: &SharedEmbeddingProvider,
101 documents: Vec<(String, T)>,
102) -> Result<Vec<PreparedDocument>, VectorStoreError>
103where
104 T: Embed + Serialize + Send + Sync + Clone,
105{
106 let mut all_texts = Vec::new();
107 let mut ranges = Vec::new();
108 let mut raws = Vec::new();
109 let mut ids = Vec::new();
110
111 for (id, doc) in documents.iter() {
112 let mut embedder = TextEmbedder::default();
113 doc.embed(&mut embedder).map_err(|err| {
114 VectorStoreError::EmbeddingError(EmbeddingError::EmbedFailure(err.to_string()))
115 })?;
116
117 if embedder.is_empty() {
118 return Err(VectorStoreError::EmbeddingError(EmbeddingError::Empty));
119 }
120
121 let start = all_texts.len();
122 let count = embedder.len();
123 all_texts.extend(embedder.into_parts());
124 ranges.push((start, count));
125 raws.push(serde_json::to_value(doc)?);
126 ids.push(id.clone());
127 }
128
129 let vectors = provider
130 .embed(all_texts.clone())
131 .await
132 .map_err(EmbeddingError::Provider)?;
133
134 let mut prepared = Vec::with_capacity(ids.len());
135 let mut vectors_iter = vectors.into_iter();
136 let mut expected_start = 0usize;
137 for ((id, raw), (start, count)) in ids.into_iter().zip(raws).zip(ranges.into_iter()) {
138 if start != expected_start {
139 return Err(VectorStoreError::EmbeddingError(
140 EmbeddingError::EmbedFailure("embedding ranges are inconsistent".into()),
141 ));
142 }
143
144 let mut embeddings = Vec::with_capacity(count);
145 for offset in 0..count {
146 let Some(vector) = vectors_iter.next() else {
147 return Err(VectorStoreError::EmbeddingError(
148 EmbeddingError::EmbedFailure(
149 "embedding provider returned fewer vectors than expected".into(),
150 ),
151 ));
152 };
153
154 embeddings.push(Embedding {
155 document: all_texts[start + offset].clone(),
156 vec: vector.into(),
157 });
158 }
159 expected_start += count;
160
161 prepared.push(PreparedDocument {
162 id,
163 raw,
164 embeddings: OneOrMany::from(embeddings),
165 });
166 }
167
168 Ok(prepared)
169}
170
171pub async fn embed_named_documents<T>(
172 provider: &SharedEmbeddingProvider,
173 documents: Vec<NamedVectorDocument<T>>,
174) -> Result<Vec<PreparedNamedVectorDocument>, VectorStoreError>
175where
176 T: Serialize + Send + Sync + Clone,
177{
178 let mut all_texts = Vec::new();
179 let mut ranges = Vec::new();
180 let mut raws = Vec::new();
181 let mut ids = Vec::new();
182 let mut names_by_doc = Vec::new();
183
184 for doc in documents {
185 if doc.vectors.is_empty() {
186 return Err(VectorStoreError::EmbeddingError(EmbeddingError::Empty));
187 }
188
189 let mut names = Vec::with_capacity(doc.vectors.len());
190 let start = all_texts.len();
191
192 for (name, text) in doc.vectors {
193 names.push(name);
194 all_texts.push(text);
195 }
196
197 ranges.push((start, names.len()));
198 names_by_doc.push(names);
199 raws.push(serde_json::to_value(doc.raw)?);
200 ids.push(doc.id);
201 }
202
203 let vectors = provider
204 .embed(all_texts.clone())
205 .await
206 .map_err(EmbeddingError::Provider)?;
207
208 let mut prepared = Vec::with_capacity(ids.len());
209 let mut vectors_iter = vectors.into_iter();
210 let mut expected_start = 0usize;
211 for (((id, raw), (start, count)), names) in ids
212 .into_iter()
213 .zip(raws)
214 .zip(ranges.into_iter())
215 .zip(names_by_doc.into_iter())
216 {
217 if start != expected_start {
218 return Err(VectorStoreError::EmbeddingError(
219 EmbeddingError::EmbedFailure("embedding ranges are inconsistent".into()),
220 ));
221 }
222
223 let mut mapped = HashMap::with_capacity(count);
224 for name in names.into_iter() {
225 let Some(vector) = vectors_iter.next() else {
226 return Err(VectorStoreError::EmbeddingError(
227 EmbeddingError::EmbedFailure(
228 "embedding provider returned fewer vectors than expected".into(),
229 ),
230 ));
231 };
232 mapped.insert(name, vector);
233 }
234 expected_start += count;
235
236 prepared.push(PreparedNamedVectorDocument {
237 id,
238 raw,
239 vectors: mapped,
240 });
241 }
242
243 Ok(prepared)
244}
245
246pub fn normalize_id(id: Option<String>) -> String {
247 id.unwrap_or_else(|| Uuid::new_v4().to_string())
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::document::Document;
    use crate::embeddings::{Embed, EmbedError, TextEmbedder};
    use autoagents_llm::embedding::EmbeddingProvider;
    use autoagents_llm::error::LLMError;
    use serde::Serialize;
    use std::sync::Arc;

    /// Provider stub that ignores its input and always returns `vectors`, so tests can
    /// simulate a provider that returns the wrong number of embeddings.
    #[derive(Debug, Clone)]
    struct DummyEmbeddingProvider {
        vectors: Vec<Vec<f32>>,
    }

    #[async_trait::async_trait]
    impl EmbeddingProvider for DummyEmbeddingProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(self.vectors.clone())
        }
    }

    /// Wraps a [`DummyEmbeddingProvider`] that returns `vectors` as a shared provider.
    fn dummy_provider(vectors: Vec<Vec<f32>>) -> SharedEmbeddingProvider {
        Arc::new(DummyEmbeddingProvider { vectors })
    }

    /// Document whose `Embed` impl emits one text per entry in `parts`.
    #[derive(Debug, Clone, Serialize)]
    struct MultiPartDoc {
        parts: Vec<String>,
    }

    impl Embed for MultiPartDoc {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            for part in &self.parts {
                embedder.embed(part.clone());
            }
            Ok(())
        }
    }

    /// Document whose `Embed` impl emits no text at all.
    #[derive(Debug, Clone, Serialize)]
    struct EmptyDoc;

    impl Embed for EmptyDoc {
        fn embed(&self, _embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            Ok(())
        }
    }

    #[test]
    fn test_normalize_id_none_generates_uuid() {
        let id = normalize_id(None);
        assert!(!id.is_empty());
        assert!(uuid::Uuid::parse_str(&id).is_ok());
    }

    #[test]
    fn test_normalize_id_some_returns_value() {
        let id = normalize_id(Some("custom-id".to_string()));
        assert_eq!(id, "custom-id");
    }

    #[tokio::test]
    async fn test_embed_documents_with_mock() {
        use crate::tests::MockLLMProvider;
        let provider: SharedEmbeddingProvider = Arc::new(MockLLMProvider {});
        let docs = vec![("id1".to_string(), Document::new("hello"))];
        // `expect` surfaces the actual error if embedding fails, unlike `is_ok()` + `unwrap()`.
        let prepared = embed_documents(&provider, docs)
            .await
            .expect("mock provider should embed a single document");
        assert_eq!(prepared.len(), 1);
        assert_eq!(prepared[0].id, "id1");
    }

    #[tokio::test]
    async fn test_embed_documents_multi_part_success() {
        let provider = dummy_provider(vec![vec![0.1_f32], vec![0.2_f32]]);
        let docs = vec![(
            "id1".to_string(),
            MultiPartDoc {
                parts: vec!["a".to_string(), "b".to_string()],
            },
        )];
        let prepared = embed_documents(&provider, docs)
            .await
            .expect("two texts with two vectors should embed");
        assert_eq!(prepared.len(), 1);
        assert_eq!(prepared[0].id, "id1");
        assert_eq!(prepared[0].raw, serde_json::json!({ "parts": ["a", "b"] }));
    }

    #[tokio::test]
    async fn test_embed_documents_no_documents() {
        let provider = dummy_provider(vec![]);
        let docs: Vec<(String, MultiPartDoc)> = Vec::new();
        let prepared = embed_documents(&provider, docs)
            .await
            .expect("an empty batch should not fail");
        assert!(prepared.is_empty());
    }

    #[tokio::test]
    async fn test_embed_documents_empty_embedder() {
        let provider = dummy_provider(vec![]);
        let docs = vec![("id1".to_string(), EmptyDoc)];
        let err = embed_documents(&provider, docs).await.unwrap_err();
        assert!(err.to_string().contains("No content to embed"));
    }

    #[tokio::test]
    async fn test_embed_documents_fewer_vectors_than_expected() {
        let provider = dummy_provider(vec![vec![0.1_f32]]);
        let docs = vec![(
            "id1".to_string(),
            MultiPartDoc {
                parts: vec!["a".to_string(), "b".to_string()],
            },
        )];
        let err = embed_documents(&provider, docs).await.unwrap_err();
        assert!(err.to_string().contains("fewer vectors"));
    }

    #[tokio::test]
    async fn test_embed_named_documents_success() {
        let provider = dummy_provider(vec![vec![0.1_f32], vec![0.2_f32]]);
        let docs = vec![NamedVectorDocument {
            id: "doc-1".to_string(),
            raw: "raw".to_string(),
            vectors: HashMap::from([
                ("title".to_string(), "hello".to_string()),
                ("body".to_string(), "world".to_string()),
            ]),
        }];
        let prepared = embed_named_documents(&provider, docs)
            .await
            .expect("two named texts with two vectors should embed");
        assert_eq!(prepared.len(), 1);
        assert_eq!(prepared[0].id, "doc-1");
        assert_eq!(prepared[0].raw, serde_json::json!("raw"));
        assert_eq!(prepared[0].vectors.len(), 2);

        // HashMap iteration order is unspecified, so compare the sorted names only.
        let mut names: Vec<&str> = prepared[0].vectors.keys().map(String::as_str).collect();
        names.sort_unstable();
        assert_eq!(names, ["body", "title"]);
    }

    #[tokio::test]
    async fn test_embed_named_documents_no_documents() {
        let provider = dummy_provider(vec![]);
        let docs: Vec<NamedVectorDocument<String>> = Vec::new();
        let prepared = embed_named_documents(&provider, docs)
            .await
            .expect("an empty batch should not fail");
        assert!(prepared.is_empty());
    }

    #[tokio::test]
    async fn test_embed_named_documents_empty_vectors() {
        let provider = dummy_provider(vec![]);
        let docs = vec![NamedVectorDocument {
            id: "doc-1".to_string(),
            raw: "raw".to_string(),
            vectors: HashMap::new(),
        }];
        let err = embed_named_documents(&provider, docs).await.unwrap_err();
        assert!(err.to_string().contains("No content to embed"));
    }

    #[tokio::test]
    async fn test_embed_named_documents_fewer_vectors() {
        let provider = dummy_provider(vec![vec![0.1_f32]]);
        let docs = vec![NamedVectorDocument {
            id: "doc-1".to_string(),
            raw: "raw".to_string(),
            vectors: HashMap::from([
                ("title".to_string(), "hello".to_string()),
                ("body".to_string(), "world".to_string()),
            ]),
        }];
        let err = embed_named_documents(&provider, docs).await.unwrap_err();
        assert!(err.to_string().contains("fewer vectors"));
    }
}