Skip to main content

autoagents_core/embeddings/
mod.rs

1use std::sync::Arc;
2
3use autoagents_llm::embedding::EmbeddingProvider;
4use autoagents_llm::error::LLMError;
5use serde::{Deserialize, Serialize};
6
7use crate::one_or_many::OneOrMany;
8
9pub mod distance;
10
11pub type SharedEmbeddingProvider = Arc<dyn EmbeddingProvider + Send + Sync>;
/// Immutable, reference-counted embedding vector; cloning bumps a refcount instead of copying the floats.
pub type VecArc = std::sync::Arc<[f32]>;
13
14#[derive(Debug, Clone, Serialize, Deserialize, Default)]
15pub struct Embedding {
16    pub document: String,
17    pub vec: VecArc,
18}
19
20impl distance::VectorDistance for Embedding {
21    fn cosine_similarity(&self, other: &Self, normalize: bool) -> f32 {
22        self.vec
23            .as_ref()
24            .cosine_similarity(other.vec.as_ref(), normalize)
25    }
26}
27
28#[derive(Debug, thiserror::Error)]
29pub enum EmbeddingError {
30    #[error("Embedding provider error: {0}")]
31    Provider(#[from] LLMError),
32
33    #[error("No content to embed")]
34    Empty,
35
36    #[error("Embedding failed: {0}")]
37    EmbedFailure(String),
38
39    #[error("Serialization error: {0}")]
40    Serialization(#[from] serde_json::Error),
41}
42
/// Buffer that collects the text fragments making up one document's embedding input.
///
/// [`Embed`] implementations push one or more fragments into it. Fragments are
/// kept in insertion order.
#[derive(Default, Debug)]
pub struct TextEmbedder {
    parts: Vec<String>,
}

impl TextEmbedder {
    /// Creates an embedder holding no fragments.
    pub fn new() -> Self {
        TextEmbedder { parts: Vec::new() }
    }

    /// Appends `text` as a new fragment.
    pub fn embed(&mut self, text: impl Into<String>) {
        let fragment: String = text.into();
        self.parts.push(fragment);
    }

    /// Number of fragments collected so far.
    pub fn len(&self) -> usize {
        Vec::len(&self.parts)
    }

    /// Returns `true` when no fragment has been added yet.
    pub fn is_empty(&self) -> bool {
        self.parts.is_empty()
    }

    /// Borrows the collected fragments, in insertion order.
    pub fn parts(&self) -> &[String] {
        self.parts.as_slice()
    }

    /// Consumes the embedder and returns its fragments, in insertion order.
    pub fn into_parts(self) -> Vec<String> {
        let Self { parts } = self;
        parts
    }
}
73
74#[derive(Debug, thiserror::Error)]
75pub enum EmbedError {
76    #[error("{0}")]
77    Message(String),
78}
79
80pub trait Embed {
81    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>;
82}
83
/// Test-only implementation: a `String` embeds as a single fragment equal to itself.
#[cfg(test)]
impl Embed for String {
    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
        embedder.embed(self.as_str());
        Ok(())
    }
}
91
92pub struct EmbeddingsBuilder<T> {
93    provider: SharedEmbeddingProvider,
94    documents: Vec<T>,
95}
96
97impl<T> EmbeddingsBuilder<T>
98where
99    T: Embed + Clone,
100{
101    pub fn new(provider: SharedEmbeddingProvider) -> Self {
102        Self {
103            provider,
104            documents: Vec::default(),
105        }
106    }
107
108    pub fn documents(mut self, docs: impl IntoIterator<Item = T>) -> Result<Self, EmbeddingError> {
109        self.documents.extend(docs);
110        if self.documents.is_empty() {
111            return Err(EmbeddingError::Empty);
112        }
113        Ok(self)
114    }
115
116    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
117        if self.documents.is_empty() {
118            return Err(EmbeddingError::Empty);
119        }
120
121        let mut texts = Vec::default();
122        let mut ranges = Vec::default();
123        for doc in &self.documents {
124            let mut embedder = TextEmbedder::default();
125            doc.embed(&mut embedder)
126                .map_err(|err| EmbeddingError::EmbedFailure(err.to_string()))?;
127
128            if embedder.is_empty() {
129                return Err(EmbeddingError::Empty);
130            }
131
132            let start = texts.len();
133            let count = embedder.len();
134            let parts = embedder.into_parts();
135            texts.extend(parts);
136            ranges.push((start, count));
137        }
138
139        let text_copy = texts.clone();
140        let vectors = self
141            .provider
142            .embed(text_copy)
143            .await
144            .map_err(EmbeddingError::Provider)?;
145
146        let mut cursor = 0usize;
147        let mut results = Vec::with_capacity(self.documents.len());
148        for (doc, (start, len)) in self.documents.into_iter().zip(ranges.into_iter()) {
149            let slice = &vectors[start..start + len];
150            let embeddings: Vec<Embedding> = slice
151                .iter()
152                .enumerate()
153                .map(|(offset, vector)| Embedding {
154                    document: texts[start + offset].clone(),
155                    vec: vector.clone().into(),
156                })
157                .collect();
158            cursor += len;
159            results.push((doc, OneOrMany::from(embeddings)));
160        }
161
162        if cursor == 0 {
163            return Err(EmbeddingError::Empty);
164        }
165
166        Ok(results)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::distance::VectorDistance;
    use super::*;

    #[test]
    fn test_text_embedder_embed_and_parts() {
        let mut embedder = TextEmbedder::default();
        embedder.embed("hello");
        embedder.embed("world");
        assert_eq!(embedder.len(), 2);
        assert!(!embedder.is_empty());
        assert_eq!(embedder.parts(), &["hello", "world"]);
    }

    #[test]
    fn test_string_embed_pushes_single_fragment() {
        let mut embedder = TextEmbedder::new();
        "hello".to_string().embed(&mut embedder).unwrap();
        assert_eq!(embedder.len(), 1);
        assert_eq!(embedder.parts(), &["hello"]);
    }

    #[test]
    fn test_embedding_cosine_similarity_identical() {
        let a = Embedding {
            document: "a".to_string(),
            vec: vec![1.0, 0.0, 0.0].into(),
        };
        let b = Embedding {
            document: "b".to_string(),
            vec: vec![1.0, 0.0, 0.0].into(),
        };
        let sim = a.cosine_similarity(&b, true);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_embedding_cosine_similarity_orthogonal() {
        let a = Embedding {
            document: "a".to_string(),
            vec: vec![1.0, 0.0, 0.0].into(),
        };
        let b = Embedding {
            document: "b".to_string(),
            vec: vec![0.0, 1.0, 0.0].into(),
        };
        let sim = a.cosine_similarity(&b, true);
        assert!(sim.abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_embeddings_builder_empty_error() {
        use crate::tests::MockLLMProvider;
        let provider: SharedEmbeddingProvider = Arc::new(MockLLMProvider {});
        let builder = EmbeddingsBuilder::<String>::new(provider);
        let result = builder.build().await;
        // Pin the specific variant, not just "some error".
        assert!(matches!(result, Err(EmbeddingError::Empty)));
    }

    #[tokio::test]
    async fn test_embeddings_builder_success() {
        use crate::tests::MockLLMProvider;
        let provider: SharedEmbeddingProvider = Arc::new(MockLLMProvider {});
        let items = EmbeddingsBuilder::new(provider)
            .documents(vec!["hello".to_string()])
            .unwrap()
            .build()
            .await
            .expect("embedding a single document should succeed");
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].0, "hello");
    }

    #[test]
    fn test_embeddings_builder_documents_empty_error() {
        use crate::tests::MockLLMProvider;
        let provider: SharedEmbeddingProvider = Arc::new(MockLLMProvider {});
        let result = EmbeddingsBuilder::<String>::new(provider).documents(Vec::<String>::new());
        assert!(matches!(result, Err(EmbeddingError::Empty)));
    }

    #[test]
    fn test_embeddings_builder_documents_accumulate() {
        use crate::tests::MockLLMProvider;
        let provider: SharedEmbeddingProvider = Arc::new(MockLLMProvider {});
        // Once documents are queued, adding an empty batch must not fail.
        let result = EmbeddingsBuilder::new(provider)
            .documents(vec!["a".to_string()])
            .unwrap()
            .documents(Vec::<String>::new());
        assert!(result.is_ok());
    }
}