Skip to main content

autoagents_core/vector_store/
mod.rs

1pub use request::VectorSearchRequest;
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use uuid::Uuid;
7
8use crate::document::Document;
9use crate::embeddings::{Embed, Embedding, EmbeddingError, SharedEmbeddingProvider, TextEmbedder};
10use crate::one_or_many::OneOrMany;
11use crate::vector_store::request::{FilterError, SearchFilter};
12
13pub mod in_memory_store;
14pub mod request;
15
/// The vector name `"default"`.
///
/// NOTE(review): not referenced in this file — presumably the name used for a
/// document's vector when no explicit vector name is given; confirm against callers.
pub const DEFAULT_VECTOR_NAME: &str = "default";
17
/// Errors returned by the vector store helpers and the [`VectorStoreIndex`] trait.
///
/// Every variant except [`VectorStoreError::BuilderError`] wraps a source error and gets a
/// `From` conversion through `#[from]`, so those source errors can be propagated with `?`.
#[derive(Debug, thiserror::Error)]
pub enum VectorStoreError {
    /// An [`EmbeddingError`], e.g. a document with nothing to embed or a provider failure.
    #[error("Embedding error: {0}")]
    EmbeddingError(#[from] EmbeddingError),

    /// A `serde_json` failure, e.g. while serializing a document to a JSON value.
    #[error("Json error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// A [`FilterError`] from the search-filter layer.
    #[error("Filter error: {0}")]
    FilterError(#[from] FilterError),

    /// Any other boxed error.
    ///
    /// NOTE(review): presumably produced by concrete store backends; none are shown here.
    #[error("Datastore error: {0}")]
    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

    /// A failure while building a [`VectorSearchRequest`], carried as a plain message.
    #[error("Error while building VectorSearchRequest: {0}")]
    BuilderError(String),
}
35
/// Interface of a vector store that accepts documents and answers search requests.
///
/// Implementations are not part of this file; the notes below describe only what the
/// signatures establish and mark everything else as an assumption.
#[async_trait]
pub trait VectorStoreIndex: Send + Sync {
    /// Filter type carried by the [`VectorSearchRequest`]s this store accepts.
    type Filter: SearchFilter + Send + Sync;

    /// Inserts `documents` without caller-supplied ids.
    ///
    /// NOTE(review): presumably the implementation assigns ids itself (see
    /// [`normalize_id`]) — confirm against implementors.
    async fn insert_documents<T>(&self, documents: Vec<T>) -> Result<(), VectorStoreError>
    where
        T: Embed + Serialize + Send + Sync + Clone;

    /// Inserts `documents`, each paired with a `String` in first position.
    ///
    /// NOTE(review): the method name indicates that string is the document id.
    async fn insert_documents_with_ids<T>(
        &self,
        documents: Vec<(String, T)>,
    ) -> Result<(), VectorStoreError>
    where
        T: Embed + Serialize + Send + Sync + Clone;

    /// Runs the search described by `req` and returns `(f64, String, T)` tuples, with `T`
    /// deserialized from whatever the store holds.
    ///
    /// NOTE(review): the tuple presumably is `(score, id, document)`, mirroring the field
    /// order of [`VectorStoreOutput`] — confirm against implementors.
    async fn top_n<T>(
        &self,
        req: VectorSearchRequest<Self::Filter>,
    ) -> Result<Vec<(f64, String, T)>, VectorStoreError>
    where
        T: for<'de> Deserialize<'de> + Send + Sync;

    /// Same request as [`VectorStoreIndex::top_n`], but returns only `(f64, String)` pairs
    /// (presumably `(score, id)`) without deserializing documents.
    async fn top_n_ids(
        &self,
        req: VectorSearchRequest<Self::Filter>,
    ) -> Result<Vec<(f64, String)>, VectorStoreError>;

    /// Inserts documents that each carry their own id and a set of named texts.
    ///
    /// Unlike the other insert methods, `T` is not required to implement [`Embed`]:
    /// the texts come from [`NamedVectorDocument::vectors`].
    async fn insert_documents_with_named_vectors<T>(
        &self,
        documents: Vec<NamedVectorDocument<T>>,
    ) -> Result<(), VectorStoreError>
    where
        T: Serialize + Send + Sync + Clone;
}
70
/// A scored [`Document`] together with its id; serializable in both directions.
///
/// NOTE(review): not used in this file — presumably one search hit as returned to
/// callers; confirm where it is constructed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreOutput {
    /// Score attached to the document (presumably its similarity to the query).
    pub score: f64,
    /// Identifier of the document.
    pub id: String,
    /// The document itself.
    pub document: Document,
}
77
/// A document after embedding, as produced by [`embed_documents`].
#[derive(Debug, Clone)]
pub struct PreparedDocument {
    /// Identifier supplied alongside the document.
    pub id: String,
    /// The document serialized to a JSON value.
    pub raw: serde_json::Value,
    /// One embedding per text part the document contributed.
    pub embeddings: OneOrMany<Embedding>,
}
84
/// Input to [`embed_named_documents`]: a document plus the named texts to embed for it.
#[derive(Debug, Clone)]
pub struct NamedVectorDocument<T> {
    /// Identifier of the document.
    pub id: String,
    /// The document payload; serialized to JSON when the document is prepared.
    pub raw: T,
    /// Vector name mapped to the text that should be embedded under that name.
    pub vectors: HashMap<String, String>,
}
91
/// A named-vector document after embedding, as produced by [`embed_named_documents`].
#[derive(Debug, Clone)]
pub struct PreparedNamedVectorDocument {
    /// Identifier of the document.
    pub id: String,
    /// The document payload serialized to a JSON value.
    pub raw: serde_json::Value,
    /// Vector name mapped to the embedding returned by the provider for that name's text.
    pub vectors: HashMap<String, Vec<f32>>,
}
98
99pub async fn embed_documents<T>(
100    provider: &SharedEmbeddingProvider,
101    documents: Vec<(String, T)>,
102) -> Result<Vec<PreparedDocument>, VectorStoreError>
103where
104    T: Embed + Serialize + Send + Sync + Clone,
105{
106    let mut all_texts = Vec::new();
107    let mut ranges = Vec::new();
108    let mut raws = Vec::new();
109    let mut ids = Vec::new();
110
111    for (id, doc) in documents.iter() {
112        let mut embedder = TextEmbedder::default();
113        doc.embed(&mut embedder).map_err(|err| {
114            VectorStoreError::EmbeddingError(EmbeddingError::EmbedFailure(err.to_string()))
115        })?;
116
117        if embedder.is_empty() {
118            return Err(VectorStoreError::EmbeddingError(EmbeddingError::Empty));
119        }
120
121        let start = all_texts.len();
122        let count = embedder.len();
123        all_texts.extend(embedder.into_parts());
124        ranges.push((start, count));
125        raws.push(serde_json::to_value(doc)?);
126        ids.push(id.clone());
127    }
128
129    let vectors = provider
130        .embed(all_texts.clone())
131        .await
132        .map_err(EmbeddingError::Provider)?;
133
134    let mut prepared = Vec::with_capacity(ids.len());
135    let mut vectors_iter = vectors.into_iter();
136    let mut expected_start = 0usize;
137    for ((id, raw), (start, count)) in ids.into_iter().zip(raws).zip(ranges.into_iter()) {
138        if start != expected_start {
139            return Err(VectorStoreError::EmbeddingError(
140                EmbeddingError::EmbedFailure("embedding ranges are inconsistent".into()),
141            ));
142        }
143
144        let mut embeddings = Vec::with_capacity(count);
145        for offset in 0..count {
146            let Some(vector) = vectors_iter.next() else {
147                return Err(VectorStoreError::EmbeddingError(
148                    EmbeddingError::EmbedFailure(
149                        "embedding provider returned fewer vectors than expected".into(),
150                    ),
151                ));
152            };
153
154            embeddings.push(Embedding {
155                document: all_texts[start + offset].clone(),
156                vec: vector.into(),
157            });
158        }
159        expected_start += count;
160
161        prepared.push(PreparedDocument {
162            id,
163            raw,
164            embeddings: OneOrMany::from(embeddings),
165        });
166    }
167
168    Ok(prepared)
169}
170
171pub async fn embed_named_documents<T>(
172    provider: &SharedEmbeddingProvider,
173    documents: Vec<NamedVectorDocument<T>>,
174) -> Result<Vec<PreparedNamedVectorDocument>, VectorStoreError>
175where
176    T: Serialize + Send + Sync + Clone,
177{
178    let mut all_texts = Vec::new();
179    let mut ranges = Vec::new();
180    let mut raws = Vec::new();
181    let mut ids = Vec::new();
182    let mut names_by_doc = Vec::new();
183
184    for doc in documents {
185        if doc.vectors.is_empty() {
186            return Err(VectorStoreError::EmbeddingError(EmbeddingError::Empty));
187        }
188
189        let mut names = Vec::with_capacity(doc.vectors.len());
190        let start = all_texts.len();
191
192        for (name, text) in doc.vectors {
193            names.push(name);
194            all_texts.push(text);
195        }
196
197        ranges.push((start, names.len()));
198        names_by_doc.push(names);
199        raws.push(serde_json::to_value(doc.raw)?);
200        ids.push(doc.id);
201    }
202
203    let vectors = provider
204        .embed(all_texts.clone())
205        .await
206        .map_err(EmbeddingError::Provider)?;
207
208    let mut prepared = Vec::with_capacity(ids.len());
209    let mut vectors_iter = vectors.into_iter();
210    let mut expected_start = 0usize;
211    for (((id, raw), (start, count)), names) in ids
212        .into_iter()
213        .zip(raws)
214        .zip(ranges.into_iter())
215        .zip(names_by_doc.into_iter())
216    {
217        if start != expected_start {
218            return Err(VectorStoreError::EmbeddingError(
219                EmbeddingError::EmbedFailure("embedding ranges are inconsistent".into()),
220            ));
221        }
222
223        let mut mapped = HashMap::with_capacity(count);
224        for name in names.into_iter() {
225            let Some(vector) = vectors_iter.next() else {
226                return Err(VectorStoreError::EmbeddingError(
227                    EmbeddingError::EmbedFailure(
228                        "embedding provider returned fewer vectors than expected".into(),
229                    ),
230                ));
231            };
232            mapped.insert(name, vector);
233        }
234        expected_start += count;
235
236        prepared.push(PreparedNamedVectorDocument {
237            id,
238            raw,
239            vectors: mapped,
240        });
241    }
242
243    Ok(prepared)
244}
245
246pub fn normalize_id(id: Option<String>) -> String {
247    id.unwrap_or_else(|| Uuid::new_v4().to_string())
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::document::Document;
    use crate::embeddings::{Embed, EmbedError, TextEmbedder};
    use autoagents_llm::embedding::EmbeddingProvider;
    use autoagents_llm::error::LLMError;
    use serde::Serialize;
    use std::sync::Arc;

    /// Provider stub that ignores its input and always answers with the same vectors.
    #[derive(Debug, Clone)]
    struct DummyEmbeddingProvider {
        vectors: Vec<Vec<f32>>,
    }

    #[async_trait::async_trait]
    impl EmbeddingProvider for DummyEmbeddingProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(self.vectors.clone())
        }
    }

    /// Document that submits each entry of `parts` as a separate text to embed.
    #[derive(Debug, Clone, Serialize)]
    struct MultiPartDoc {
        parts: Vec<String>,
    }

    impl Embed for MultiPartDoc {
        fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            self.parts.iter().cloned().for_each(|part| {
                embedder.embed(part);
            });
            Ok(())
        }
    }

    /// Document that submits nothing to embed.
    #[derive(Debug, Clone, Serialize)]
    struct EmptyDoc;

    impl Embed for EmptyDoc {
        fn embed(&self, _embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
            Ok(())
        }
    }

    /// Wraps a [`DummyEmbeddingProvider`] answering with `vectors` as a shared provider.
    fn canned_provider(vectors: Vec<Vec<f32>>) -> SharedEmbeddingProvider {
        Arc::new(DummyEmbeddingProvider { vectors })
    }

    /// Builds a single named-vector document `doc-1` carrying the given name/text pairs.
    fn named_doc(entries: &[(&str, &str)]) -> Vec<NamedVectorDocument<String>> {
        let vectors: HashMap<String, String> = entries
            .iter()
            .map(|(name, text)| (name.to_string(), text.to_string()))
            .collect();
        vec![NamedVectorDocument {
            id: "doc-1".to_string(),
            raw: "raw".to_string(),
            vectors,
        }]
    }

    #[test]
    fn test_normalize_id_none_generates_uuid() {
        let generated = normalize_id(None);
        assert!(!generated.is_empty());
        assert!(uuid::Uuid::parse_str(&generated).is_ok());
    }

    #[test]
    fn test_normalize_id_some_returns_value() {
        let kept = normalize_id(Some("custom-id".to_string()));
        assert_eq!(kept, "custom-id");
    }

    #[tokio::test]
    async fn test_embed_documents_with_mock() {
        use crate::tests::MockLLMProvider;
        let provider: SharedEmbeddingProvider = Arc::new(MockLLMProvider {});
        let input = vec![("id1".to_string(), Document::new("hello"))];
        let outcome = embed_documents(&provider, input).await;
        assert!(outcome.is_ok());
        let prepared = outcome.unwrap();
        assert_eq!(prepared.len(), 1);
        assert_eq!(prepared[0].id, "id1");
    }

    #[tokio::test]
    async fn test_embed_documents_empty_embedder() {
        let provider = canned_provider(vec![]);
        let input = vec![("id1".to_string(), EmptyDoc)];
        let failure = embed_documents(&provider, input).await.unwrap_err();
        assert!(failure.to_string().contains("No content to embed"));
    }

    #[tokio::test]
    async fn test_embed_documents_fewer_vectors_than_expected() {
        // One vector for a two-part document.
        let provider = canned_provider(vec![vec![0.1_f32]]);
        let two_parts = MultiPartDoc {
            parts: vec!["a".to_string(), "b".to_string()],
        };
        let input = vec![("id1".to_string(), two_parts)];
        let failure = embed_documents(&provider, input).await.unwrap_err();
        assert!(failure.to_string().contains("fewer vectors"));
    }

    #[tokio::test]
    async fn test_embed_named_documents_success() {
        let provider = canned_provider(vec![vec![0.1_f32], vec![0.2_f32]]);
        let input = named_doc(&[("title", "hello"), ("body", "world")]);
        let prepared = embed_named_documents(&provider, input).await.unwrap();
        assert_eq!(prepared.len(), 1);
        assert_eq!(prepared[0].vectors.len(), 2);
    }

    #[tokio::test]
    async fn test_embed_named_documents_empty_vectors() {
        let provider = canned_provider(vec![]);
        let input = named_doc(&[]);
        let failure = embed_named_documents(&provider, input).await.unwrap_err();
        assert!(failure.to_string().contains("No content to embed"));
    }

    #[tokio::test]
    async fn test_embed_named_documents_fewer_vectors() {
        // One vector for two named texts.
        let provider = canned_provider(vec![vec![0.1_f32]]);
        let input = named_doc(&[("title", "hello"), ("body", "world")]);
        let failure = embed_named_documents(&provider, input).await.unwrap_err();
        assert!(failure.to_string().contains("fewer vectors"));
    }
}
391}