// autoagents_core/embeddings/mod.rs

1use std::sync::Arc;
2
3use autoagents_llm::embedding::EmbeddingProvider;
4use autoagents_llm::error::LLMError;
5use serde::{Deserialize, Serialize};
6
7use crate::one_or_many::OneOrMany;
8
9pub mod distance;
10
11pub type SharedEmbeddingProvider = Arc<dyn EmbeddingProvider + Send + Sync>;
12
13#[derive(Debug, Clone, Serialize, Deserialize, Default)]
14pub struct Embedding {
15    pub document: String,
16    pub vec: Vec<f32>,
17}
18
19impl distance::VectorDistance for Embedding {
20    fn cosine_similarity(&self, other: &Self, normalize: bool) -> f32 {
21        self.vec.cosine_similarity(&other.vec, normalize)
22    }
23}
24
25#[derive(Debug, thiserror::Error)]
26pub enum EmbeddingError {
27    #[error("Embedding provider error: {0}")]
28    Provider(#[from] LLMError),
29
30    #[error("No content to embed")]
31    Empty,
32
33    #[error("Embedding failed: {0}")]
34    EmbedFailure(String),
35
36    #[error("Serialization error: {0}")]
37    Serialization(#[from] serde_json::Error),
38}
39
/// Collects the text fragments that a document wants embedded.
///
/// An [`Embed`] implementation receives a `&mut TextEmbedder` and pushes one
/// or more strings into it. Fragments are kept in insertion order.
#[derive(Default, Debug)]
pub struct TextEmbedder {
    parts: Vec<String>,
}

impl TextEmbedder {
    /// Creates an embedder holding no fragments.
    pub fn new() -> Self {
        Self { parts: Vec::new() }
    }

    /// Appends one text fragment.
    pub fn embed(&mut self, text: impl Into<String>) {
        let fragment: String = text.into();
        self.parts.push(fragment);
    }

    /// Returns how many fragments have been appended.
    pub fn len(&self) -> usize {
        self.parts().len()
    }

    /// Returns `true` if no fragments have been appended.
    pub fn is_empty(&self) -> bool {
        self.parts().is_empty()
    }

    /// Borrows the fragments in insertion order.
    pub fn parts(&self) -> &[String] {
        self.parts.as_slice()
    }

    /// Consumes the embedder and returns its fragments in insertion order.
    pub fn into_parts(self) -> Vec<String> {
        let Self { parts } = self;
        parts
    }
}
70
71#[derive(Debug, thiserror::Error)]
72pub enum EmbedError {
73    #[error("{0}")]
74    Message(String),
75}
76
77pub trait Embed {
78    fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>;
79}
80
81pub struct EmbeddingsBuilder<T> {
82    provider: SharedEmbeddingProvider,
83    documents: Vec<T>,
84}
85
86impl<T> EmbeddingsBuilder<T>
87where
88    T: Embed + Clone,
89{
90    pub fn new(provider: SharedEmbeddingProvider) -> Self {
91        Self {
92            provider,
93            documents: Vec::new(),
94        }
95    }
96
97    pub fn documents(mut self, docs: impl IntoIterator<Item = T>) -> Result<Self, EmbeddingError> {
98        self.documents.extend(docs);
99        if self.documents.is_empty() {
100            return Err(EmbeddingError::Empty);
101        }
102        Ok(self)
103    }
104
105    pub async fn build(self) -> Result<Vec<(T, OneOrMany<Embedding>)>, EmbeddingError> {
106        if self.documents.is_empty() {
107            return Err(EmbeddingError::Empty);
108        }
109
110        let mut texts = Vec::new();
111        let mut ranges = Vec::new();
112        for doc in &self.documents {
113            let mut embedder = TextEmbedder::new();
114            doc.embed(&mut embedder)
115                .map_err(|err| EmbeddingError::EmbedFailure(err.to_string()))?;
116
117            if embedder.is_empty() {
118                return Err(EmbeddingError::Empty);
119            }
120
121            let start = texts.len();
122            let count = embedder.len();
123            let parts = embedder.into_parts();
124            texts.extend(parts);
125            ranges.push((start, count));
126        }
127
128        let text_copy = texts.clone();
129        let vectors = self
130            .provider
131            .embed(text_copy)
132            .await
133            .map_err(EmbeddingError::Provider)?;
134
135        let mut cursor = 0usize;
136        let mut results = Vec::with_capacity(self.documents.len());
137        for (doc, (start, len)) in self.documents.into_iter().zip(ranges.into_iter()) {
138            let slice = &vectors[start..start + len];
139            let embeddings: Vec<Embedding> = slice
140                .iter()
141                .enumerate()
142                .map(|(offset, vector)| Embedding {
143                    document: texts[start + offset].clone(),
144                    vec: vector.clone(),
145                })
146                .collect();
147            cursor += len;
148            results.push((doc, OneOrMany::from(embeddings)));
149        }
150
151        if cursor == 0 {
152            return Err(EmbeddingError::Empty);
153        }
154
155        Ok(results)
156    }
157}