// autoagents_core/embeddings/mod.rs
use std::sync::Arc;

use autoagents_llm::embedding::EmbeddingProvider;
use autoagents_llm::error::LLMError;
use serde::{Deserialize, Serialize};

use crate::one_or_many::OneOrMany;

9pub mod distance;
10
11pub type SharedEmbeddingProvider = Arc<dyn EmbeddingProvider + Send + Sync>;
12pub type VecArc = Arc<[f32]>;
13
14#[derive(Debug, Clone, Serialize, Deserialize, Default)]
15pub struct Embedding {
16 pub document: String,
17 pub vec: VecArc,
18}
19
20impl distance::VectorDistance for Embedding {
21 fn cosine_similarity(&self, other: &Self, normalize: bool) -> f32 {
22 self.vec
23 .as_ref()
24 .cosine_similarity(other.vec.as_ref(), normalize)
25 }
26}
27
28#[derive(Debug, thiserror::Error)]
29pub enum EmbeddingError {
30 #[error("Embedding provider error: {0}")]
31 Provider(#[from] LLMError),
32
33 #[error("No content to embed")]
34 Empty,
35
36 #[error("Embedding failed: {0}")]
37 EmbedFailure(String),
38
39 #[error("Serialization error: {0}")]
40 Serialization(#[from] serde_json::Error),
41}
42
/// Collects, in insertion order, the text parts a document wants embedded.
///
/// An `Embed` implementation pushes one or more parts into this collector.
#[derive(Debug, Default)]
pub struct TextEmbedder {
    parts: Vec<String>,
}

impl TextEmbedder {
    /// Creates an empty collector.
    pub fn new() -> Self {
        Self { parts: Vec::new() }
    }

    /// Appends `text` as a new part.
    pub fn embed(&mut self, text: impl Into<String>) {
        let part: String = text.into();
        self.parts.push(part);
    }

    /// Number of parts collected so far.
    pub fn len(&self) -> usize {
        self.parts().len()
    }

    /// Returns `true` when no parts have been collected.
    pub fn is_empty(&self) -> bool {
        self.parts().is_empty()
    }

    /// Borrows the collected parts in insertion order.
    pub fn parts(&self) -> &[String] {
        self.parts.as_slice()
    }

    /// Consumes the collector, returning the owned parts in insertion order.
    pub fn into_parts(self) -> Vec<String> {
        let Self { parts } = self;
        parts
    }
}

74#[derive(Debug, thiserror::Error)]
75pub enum EmbedError {
76 #[error("{0}")]
77 Message(String),
78}
79
80pub trait Embed {
81 fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>;
82}
83
// Test-only: a `String` embeds as exactly one part holding its full text.
#[cfg(test)]
impl Embed for String {
    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> {
        embedder.embed(self.as_str());
        Ok(())
    }
}

92pub struct EmbeddingsBuilder<T> {
93 provider: SharedEmbeddingProvider,
94 documents: Vec<T>,
95}
96
97impl<T> EmbeddingsBuilder<T>
98where
99 T: Embed + Clone,
100{
101 pub fn new(provider: SharedEmbeddingProvider) -> Self {
102 Self {
103 provider,
104 documents: Vec::default(),
105 }
106 }
107
108 pub fn documents(mut self, docs: impl IntoIterator<Item = T>) -> Result<Self, EmbeddingError> {
109 self.documents.extend(docs);
110 if self.documents.is_empty() {
111 return Err(EmbeddingError::Empty);
112 }
113 Ok(self)
114 }
115
116 pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
117 if self.documents.is_empty() {
118 return Err(EmbeddingError::Empty);
119 }
120
121 let mut texts = Vec::default();
122 let mut ranges = Vec::default();
123 for doc in &self.documents {
124 let mut embedder = TextEmbedder::default();
125 doc.embed(&mut embedder)
126 .map_err(|err| EmbeddingError::EmbedFailure(err.to_string()))?;
127
128 if embedder.is_empty() {
129 return Err(EmbeddingError::Empty);
130 }
131
132 let start = texts.len();
133 let count = embedder.len();
134 let parts = embedder.into_parts();
135 texts.extend(parts);
136 ranges.push((start, count));
137 }
138
139 let text_copy = texts.clone();
140 let vectors = self
141 .provider
142 .embed(text_copy)
143 .await
144 .map_err(EmbeddingError::Provider)?;
145
146 let mut cursor = 0usize;
147 let mut results = Vec::with_capacity(self.documents.len());
148 for (doc, (start, len)) in self.documents.into_iter().zip(ranges.into_iter()) {
149 let slice = &vectors[start..start + len];
150 let embeddings: Vec<Embedding> = slice
151 .iter()
152 .enumerate()
153 .map(|(offset, vector)| Embedding {
154 document: texts[start + offset].clone(),
155 vec: vector.clone().into(),
156 })
157 .collect();
158 cursor += len;
159 results.push((doc, OneOrMany::from(embeddings)));
160 }
161
162 if cursor == 0 {
163 return Err(EmbeddingError::Empty);
164 }
165
166 Ok(results)
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::distance::VectorDistance;
    use super::*;
    use crate::tests::MockLLMProvider;

    /// Shared mock provider used by the builder tests.
    fn mock_provider() -> SharedEmbeddingProvider {
        Arc::new(MockLLMProvider {})
    }

    /// Builds an `Embedding` from a label and raw vector values.
    fn embedding(document: &str, values: &[f32]) -> Embedding {
        Embedding {
            document: document.to_string(),
            vec: values.to_vec().into(),
        }
    }

    #[test]
    fn test_text_embedder_embed_and_parts() {
        let mut embedder = TextEmbedder::default();
        for word in ["hello", "world"] {
            embedder.embed(word);
        }
        assert_eq!(embedder.len(), 2);
        assert!(!embedder.is_empty());
        assert_eq!(embedder.parts(), &["hello", "world"]);
    }

    #[test]
    fn test_embedding_cosine_similarity_identical() {
        let a = embedding("a", &[1.0, 0.0, 0.0]);
        let b = embedding("b", &[1.0, 0.0, 0.0]);
        let sim = a.cosine_similarity(&b, true);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_embedding_cosine_similarity_orthogonal() {
        let a = embedding("a", &[1.0, 0.0, 0.0]);
        let b = embedding("b", &[0.0, 1.0, 0.0]);
        let sim = a.cosine_similarity(&b, true);
        assert!(sim.abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_embeddings_builder_empty_error() {
        let result = EmbeddingsBuilder::<String>::new(mock_provider())
            .build()
            .await;
        // Pin the specific variant, not just "some error".
        assert!(matches!(result, Err(EmbeddingError::Empty)));
    }

    #[tokio::test]
    async fn test_embeddings_builder_success() {
        let items = EmbeddingsBuilder::new(mock_provider())
            .documents(vec!["hello".to_string()])
            .unwrap()
            .build()
            .await
            .expect("a single non-empty document should embed successfully");
        assert_eq!(items.len(), 1);
        assert_eq!(items[0].0, "hello");
    }

    #[test]
    fn test_embeddings_builder_documents_empty_error() {
        let result =
            EmbeddingsBuilder::<String>::new(mock_provider()).documents(Vec::<String>::new());
        assert!(matches!(result, Err(EmbeddingError::Empty)));
    }
}