// autoagents_core/vector_store/mod.rs

1pub use request::VectorSearchRequest;
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use uuid::Uuid;
6
7use crate::document::Document;
8use crate::embeddings::{Embed, Embedding, EmbeddingError, SharedEmbeddingProvider, TextEmbedder};
9use crate::one_or_many::OneOrMany;
10use crate::vector_store::request::{FilterError, SearchFilter};
11
12pub mod in_memory_store;
13pub mod request;
14
15#[derive(Debug, thiserror::Error)]
16pub enum VectorStoreError {
17    #[error("Embedding error: {0}")]
18    EmbeddingError(#[from] EmbeddingError),
19
20    #[error("Json error: {0}")]
21    JsonError(#[from] serde_json::Error),
22
23    #[error("Filter error: {0}")]
24    FilterError(#[from] FilterError),
25
26    #[error("Datastore error: {0}")]
27    DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),
28
29    #[error("Error while building VectorSearchRequest: {0}")]
30    BuilderError(String),
31}
32
33#[async_trait]
34pub trait VectorStoreIndex: Send + Sync {
35    type Filter: SearchFilter + Send + Sync;
36
37    async fn insert_documents<T>(&self, documents: Vec<T>) -> Result<(), VectorStoreError>
38    where
39        T: Embed + Serialize + Send + Sync + Clone;
40
41    async fn insert_documents_with_ids<T>(
42        &self,
43        documents: Vec<(String, T)>,
44    ) -> Result<(), VectorStoreError>
45    where
46        T: Embed + Serialize + Send + Sync + Clone;
47
48    async fn top_n<T>(
49        &self,
50        req: VectorSearchRequest<Self::Filter>,
51    ) -> Result<Vec<(f64, String, T)>, VectorStoreError>
52    where
53        T: for<'de> Deserialize<'de> + Send + Sync;
54
55    async fn top_n_ids(
56        &self,
57        req: VectorSearchRequest<Self::Filter>,
58    ) -> Result<Vec<(f64, String)>, VectorStoreError>;
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct VectorStoreOutput {
63    pub score: f64,
64    pub id: String,
65    pub document: Document,
66}
67
68#[derive(Debug, Clone)]
69pub struct PreparedDocument {
70    pub id: String,
71    pub raw: serde_json::Value,
72    pub embeddings: OneOrMany<Embedding>,
73}
74
75pub async fn embed_documents<T>(
76    provider: &SharedEmbeddingProvider,
77    documents: Vec<(String, T)>,
78) -> Result<Vec<PreparedDocument>, VectorStoreError>
79where
80    T: Embed + Serialize + Send + Sync + Clone,
81{
82    let mut all_texts = Vec::new();
83    let mut ranges = Vec::new();
84    let mut raws = Vec::new();
85    let mut ids = Vec::new();
86
87    for (id, doc) in documents.iter() {
88        let mut embedder = TextEmbedder::new();
89        doc.embed(&mut embedder).map_err(|err| {
90            VectorStoreError::EmbeddingError(EmbeddingError::EmbedFailure(err.to_string()))
91        })?;
92
93        if embedder.is_empty() {
94            return Err(VectorStoreError::EmbeddingError(EmbeddingError::Empty));
95        }
96
97        let start = all_texts.len();
98        let count = embedder.len();
99        all_texts.extend(embedder.into_parts());
100        ranges.push((start, count));
101        raws.push(serde_json::to_value(doc)?);
102        ids.push(id.clone());
103    }
104
105    let vectors = provider
106        .embed(all_texts.clone())
107        .await
108        .map_err(EmbeddingError::Provider)?;
109
110    let mut prepared = Vec::with_capacity(ids.len());
111    for ((id, raw), (start, count)) in ids.into_iter().zip(raws).zip(ranges.into_iter()) {
112        let embeddings: Vec<Embedding> = vectors[start..start + count]
113            .iter()
114            .enumerate()
115            .map(|(offset, vector)| Embedding {
116                document: all_texts[start + offset].clone(),
117                vec: vector.clone(),
118            })
119            .collect();
120
121        prepared.push(PreparedDocument {
122            id,
123            raw,
124            embeddings: OneOrMany::from(embeddings),
125        });
126    }
127
128    Ok(prepared)
129}
130
131pub fn normalize_id(id: Option<String>) -> String {
132    id.unwrap_or_else(|| Uuid::new_v4().to_string())
133}