autoagents_core/embeddings/
mod.rs1use std::sync::Arc;
2
3use autoagents_llm::embedding::EmbeddingProvider;
4use autoagents_llm::error::LLMError;
5use serde::{Deserialize, Serialize};
6
7use crate::one_or_many::OneOrMany;
8
9pub mod distance;
10
11pub type SharedEmbeddingProvider = Arc<dyn EmbeddingProvider + Send + Sync>;
12
13#[derive(Debug, Clone, Serialize, Deserialize, Default)]
14pub struct Embedding {
15 pub document: String,
16 pub vec: Vec<f32>,
17}
18
19impl distance::VectorDistance for Embedding {
20 fn cosine_similarity(&self, other: &Self, normalize: bool) -> f32 {
21 self.vec.cosine_similarity(&other.vec, normalize)
22 }
23}
24
25#[derive(Debug, thiserror::Error)]
26pub enum EmbeddingError {
27 #[error("Embedding provider error: {0}")]
28 Provider(#[from] LLMError),
29
30 #[error("No content to embed")]
31 Empty,
32
33 #[error("Embedding failed: {0}")]
34 EmbedFailure(String),
35
36 #[error("Serialization error: {0}")]
37 Serialization(#[from] serde_json::Error),
38}
39
/// Collector for the text fragments an [`Embed`] implementation wants embedded.
///
/// Fragments are kept in insertion order.
#[derive(Debug, Default)]
pub struct TextEmbedder {
    parts: Vec<String>,
}

impl TextEmbedder {
    /// Creates a collector holding no fragments.
    pub fn new() -> Self {
        Self { parts: Vec::new() }
    }

    /// Appends one text fragment.
    pub fn embed(&mut self, text: impl Into<String>) {
        let fragment: String = text.into();
        self.parts.push(fragment);
    }

    /// Number of fragments collected so far.
    pub fn len(&self) -> usize {
        self.parts().len()
    }

    /// Returns `true` when no fragment has been collected.
    pub fn is_empty(&self) -> bool {
        self.parts().is_empty()
    }

    /// Borrows the collected fragments, in insertion order.
    pub fn parts(&self) -> &[String] {
        self.parts.as_slice()
    }

    /// Consumes the collector and returns its fragments, in insertion order.
    pub fn into_parts(self) -> Vec<String> {
        let Self { parts } = self;
        parts
    }
}
70
71#[derive(Debug, thiserror::Error)]
72pub enum EmbedError {
73 #[error("{0}")]
74 Message(String),
75}
76
77pub trait Embed {
78 fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>;
79}
80
81pub struct EmbeddingsBuilder<T> {
82 provider: SharedEmbeddingProvider,
83 documents: Vec<T>,
84}
85
86impl<T> EmbeddingsBuilder<T>
87where
88 T: Embed + Clone,
89{
90 pub fn new(provider: SharedEmbeddingProvider) -> Self {
91 Self {
92 provider,
93 documents: Vec::new(),
94 }
95 }
96
97 pub fn documents(mut self, docs: impl IntoIterator<Item = T>) -> Result<Self, EmbeddingError> {
98 self.documents.extend(docs);
99 if self.documents.is_empty() {
100 return Err(EmbeddingError::Empty);
101 }
102 Ok(self)
103 }
104
105 pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
106 if self.documents.is_empty() {
107 return Err(EmbeddingError::Empty);
108 }
109
110 let mut texts = Vec::new();
111 let mut ranges = Vec::new();
112 for doc in &self.documents {
113 let mut embedder = TextEmbedder::new();
114 doc.embed(&mut embedder)
115 .map_err(|err| EmbeddingError::EmbedFailure(err.to_string()))?;
116
117 if embedder.is_empty() {
118 return Err(EmbeddingError::Empty);
119 }
120
121 let start = texts.len();
122 let count = embedder.len();
123 let parts = embedder.into_parts();
124 texts.extend(parts);
125 ranges.push((start, count));
126 }
127
128 let text_copy = texts.clone();
129 let vectors = self
130 .provider
131 .embed(text_copy)
132 .await
133 .map_err(EmbeddingError::Provider)?;
134
135 let mut cursor = 0usize;
136 let mut results = Vec::with_capacity(self.documents.len());
137 for (doc, (start, len)) in self.documents.into_iter().zip(ranges.into_iter()) {
138 let slice = &vectors[start..start + len];
139 let embeddings: Vec<Embedding> = slice
140 .iter()
141 .enumerate()
142 .map(|(offset, vector)| Embedding {
143 document: texts[start + offset].clone(),
144 vec: vector.clone(),
145 })
146 .collect();
147 cursor += len;
148 results.push((doc, OneOrMany::from(embeddings)));
149 }
150
151 if cursor == 0 {
152 return Err(EmbeddingError::Empty);
153 }
154
155 Ok(results)
156 }
157}