llm_brain/embeddings.rs

use std::sync::OnceLock;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tiktoken_rs::{CoreBPE, cl100k_base};

use crate::error::{LLMBrainError, Result};

/// Cached tokenizer instance for improved performance
static BPE: OnceLock<CoreBPE> = OnceLock::new();

/// Vector embedding provider trait
///
/// Types implementing this trait can generate vector embeddings for text
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generate the embedding vector for a single text
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embedding vectors for multiple texts
    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Count the tokens in a text
    fn count_tokens(&self, text: &str) -> Result<usize>;
}

/// Embedding model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingModelConfig {
    /// Embedding model name
    pub model_name: String,

    /// Embedding vector dimensions
    pub dimensions: usize,

    /// Maximum context window size (in tokens)
    pub max_context_length: usize,
}

impl Default for EmbeddingModelConfig {
    fn default() -> Self {
        Self {
            model_name: "text-embedding-3-small".to_owned(),
            dimensions: 1536,
            max_context_length: 8191,
        }
    }
}
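
// An alternative configuration is built the same way. For example, for
// OpenAI's text-embedding-3-large (3072 dimensions, same 8191-token input
// limit, per OpenAI's published model specs):
//
//     let config = EmbeddingModelConfig {
//         model_name: "text-embedding-3-large".to_owned(),
//         dimensions: 3072,
//         max_context_length: 8191,
//     };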

/// Embedding generation middleware
///
/// Wraps an [`EmbeddingProvider`] to add pre-processing and post-processing
/// around embedding requests. Currently this covers:
/// - Text normalization
/// - Optional vector normalization
///
/// It is also the natural extension point for embedding caching and batch
/// processing optimizations.
pub struct EmbeddingMiddleware<P: EmbeddingProvider> {
    provider: P,
    normalize_vectors: bool,
}

impl<P: EmbeddingProvider> EmbeddingMiddleware<P> {
    /// Create a new embedding middleware instance
    pub fn new(provider: P, normalize_vectors: bool) -> Self {
        Self {
            provider,
            normalize_vectors,
        }
    }

    /// Initialize the cached tokenizer, loading it on first call
    pub fn initialize_tokenizer() -> Result<()> {
        if BPE.get().is_none() {
            // Propagate the load error instead of panicking, since this
            // function already returns a Result
            let bpe = cl100k_base().map_err(|e| {
                LLMBrainError::InputError(format!("Failed to load cl100k_base tokenizer: {e}"))
            })?;
            // Ignore the race where another thread initialized the cell first
            let _ = BPE.set(bpe);
        }
        Ok(())
    }

    /// Normalize text content
    pub fn normalize_text(&self, text: &str) -> String {
        // Basic text normalization: trim surrounding whitespace and collapse
        // runs of internal whitespace into single spaces
        text.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    /// Normalize a vector in place so it becomes a unit vector
    pub fn normalize_vector(&self, vector: &mut [f32]) {
        if !self.normalize_vectors {
            return;
        }

        // Calculate the Euclidean norm of the vector
        let norm = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();

        // Avoid division by zero
        if norm > 1e-10 {
            for x in vector.iter_mut() {
                *x /= norm;
            }
        }
    }
}

#[async_trait]
impl<P: EmbeddingProvider + Send + Sync> EmbeddingProvider for EmbeddingMiddleware<P> {
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let normalized_text = self.normalize_text(text);
        let mut embedding = self.provider.generate_embedding(&normalized_text).await?;
        self.normalize_vector(&mut embedding);
        Ok(embedding)
    }

    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        let normalized_texts = texts
            .iter()
            .map(|text| self.normalize_text(text))
            .collect::<Vec<String>>();

        let mut embeddings = self.provider.generate_embeddings(normalized_texts).await?;

        for embedding in &mut embeddings {
            self.normalize_vector(embedding);
        }

        Ok(embeddings)
    }

    fn count_tokens(&self, text: &str) -> Result<usize> {
        self.provider.count_tokens(text)
    }
}
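
// A minimal usage sketch, assuming some type `MyProvider` (hypothetical, not
// defined in this module) implements `EmbeddingProvider`:
//
//     let middleware = EmbeddingMiddleware::new(MyProvider::new(), true);
//     let embedding = middleware.generate_embedding("hello world").await?;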

/// Text chunk-based embedding strategy
///
/// Defines how long texts are handled when generating embeddings: splitting,
/// averaging, using a specific part of the text, and so on.
#[derive(Debug, Clone, Default)]
pub enum ChunkingStrategy {
    /// Use the model's full context window and truncate anything beyond it;
    /// no chunking is performed
    #[default]
    NoChunking,

    /// Split long text into chunks, generate an embedding for each chunk,
    /// then average the results
    ChunkAndAverage {
        chunk_size: usize,
        chunk_overlap: usize,
    },

    /// Only use the first N tokens of the text
    UsePrefix(usize),

    /// Only use the last N tokens of the text
    UseSuffix(usize),
}
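
// Example: a strategy for long documents that splits them into 512-token
// chunks with a 64-token overlap (the sizes are illustrative, not tuned
// defaults):
//
//     let strategy = ChunkingStrategy::ChunkAndAverage {
//         chunk_size: 512,
//         chunk_overlap: 64,
//     };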

/// Long text embedding handler
pub struct LongTextHandler<P: EmbeddingProvider> {
    provider: P,
    model_config: EmbeddingModelConfig,
    chunking_strategy: ChunkingStrategy,
}

impl<P: EmbeddingProvider> LongTextHandler<P> {
    /// Create a new long text handler
    pub fn new(
        provider: P,
        model_config: EmbeddingModelConfig,
        chunking_strategy: ChunkingStrategy,
    ) -> Self {
        Self {
            provider,
            model_config,
            chunking_strategy,
        }
    }

    /// Truncate text to the maximum context window size
    pub fn truncate_text(&self, text: &str) -> Result<String> {
        let token_count = self.provider.count_tokens(text)?;

        if token_count <= self.model_config.max_context_length {
            return Ok(text.to_owned());
        }

        // Truncate at a token boundary using the cached tokenizer. The
        // provider's token count may differ from cl100k_base's, so clamp the
        // cut-off to the encoded length to avoid an out-of-bounds slice.
        let bpe = BPE.get().expect("BPE tokenizer not initialized");
        let tokens = bpe.encode_with_special_tokens(text);
        let end = self.model_config.max_context_length.min(tokens.len());
        let truncated_tokens = tokens[..end].to_vec();

        // Decode back to text
        bpe.decode(truncated_tokens)
            .map_err(|e| LLMBrainError::InputError(format!("Failed to decode tokens: {e}")))
    }

    /// Chunk text into overlapping pieces
    pub fn chunk_text(&self, text: &str, chunk_size: usize, overlap: usize) -> Result<Vec<String>> {
        // Reject parameters that would cause an infinite loop or an
        // arithmetic underflow below
        if chunk_size == 0 || overlap >= chunk_size {
            return Err(LLMBrainError::InputError(
                "chunk_size must be non-zero and greater than overlap".to_owned(),
            ));
        }

        let bpe = BPE.get().expect("BPE tokenizer not initialized");
        let tokens = bpe.encode_with_special_tokens(text);

        if tokens.len() <= chunk_size {
            return Ok(vec![text.to_owned()]);
        }

        let mut chunks = Vec::new();
        let mut start = 0;

        while start < tokens.len() {
            let end = (start + chunk_size).min(tokens.len());
            let chunk_tokens = tokens[start..end].to_vec();

            // Decode the current chunk
            let chunk = bpe
                .decode(chunk_tokens)
                .map_err(|e| LLMBrainError::InputError(format!("Failed to decode tokens: {e}")))?;

            chunks.push(chunk);

            if end >= tokens.len() {
                break;
            }

            // Calculate the starting position of the next chunk, taking the
            // overlap into account
            start = end - overlap;
        }

        Ok(chunks)
    }

    /// Generate an embedding for the text according to the chunking strategy
    pub async fn process_embeddings(&self, text: &str) -> Result<Vec<f32>> {
        match &self.chunking_strategy {
            ChunkingStrategy::NoChunking => {
                let truncated = self.truncate_text(text)?;
                self.provider.generate_embedding(&truncated).await
            }

            ChunkingStrategy::ChunkAndAverage {
                chunk_size,
                chunk_overlap,
            } => {
                let chunks = self.chunk_text(text, *chunk_size, *chunk_overlap)?;
                if chunks.is_empty() {
                    return Err(LLMBrainError::InputError(
                        "No chunks generated from text".to_owned(),
                    ));
                }

                let embeddings = self.provider.generate_embeddings(chunks).await?;

                if embeddings.is_empty() {
                    return Err(LLMBrainError::ApiError(
                        "No embeddings generated".to_owned(),
                    ));
                }

                // Average all embedding vectors element-wise
                let dimensions = embeddings[0].len();
                let mut average = vec![0.0; dimensions];

                for embedding in &embeddings {
                    for (i, &value) in embedding.iter().enumerate() {
                        average[i] += value / embeddings.len() as f32;
                    }
                }

                Ok(average)
            }

            ChunkingStrategy::UsePrefix(size) => {
                let bpe = BPE.get().expect("BPE tokenizer not initialized");
                let tokens = bpe.encode_with_special_tokens(text);

                let end = (*size).min(tokens.len());
                let prefix_tokens = tokens[..end].to_vec();

                let prefix = bpe.decode(prefix_tokens).map_err(|e| {
                    LLMBrainError::InputError(format!("Failed to decode tokens: {e}"))
                })?;

                self.provider.generate_embedding(&prefix).await
            }

            ChunkingStrategy::UseSuffix(size) => {
                let bpe = BPE.get().expect("BPE tokenizer not initialized");
                let tokens = bpe.encode_with_special_tokens(text);

                // Take the last N tokens while preserving their original
                // order, so the decoded suffix reads correctly
                let start = tokens.len().saturating_sub(*size);
                let suffix_tokens = tokens[start..].to_vec();

                let suffix = bpe.decode(suffix_tokens).map_err(|e| {
                    LLMBrainError::InputError(format!("Failed to decode tokens: {e}"))
                })?;

                self.provider.generate_embedding(&suffix).await
            }
        }
    }
}
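
// Putting the pieces together, an illustrative pipeline (again assuming a
// hypothetical `MyProvider`) that normalizes vectors and averages chunk
// embeddings for long documents:
//
//     let provider = EmbeddingMiddleware::new(MyProvider::new(), true);
//     let handler = LongTextHandler::new(
//         provider,
//         EmbeddingModelConfig::default(),
//         ChunkingStrategy::ChunkAndAverage { chunk_size: 512, chunk_overlap: 64 },
//     );
//     let embedding = handler.process_embeddings(&long_document).await?;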

#[cfg(test)]
mod tests {
    use super::*;

    // Mock embedding provider implementation for testing
    struct MockEmbeddingProvider;

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn generate_embedding(&self, _text: &str) -> Result<Vec<f32>> {
            // Return a fixed test embedding vector
            Ok(vec![0.1, 0.2, 0.3, 0.4])
        }

        async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            // Return the same fixed test embedding for each input text
            Ok(texts.iter().map(|_| vec![0.1, 0.2, 0.3, 0.4]).collect())
        }

        fn count_tokens(&self, text: &str) -> Result<usize> {
            // Simple estimation: assume each whitespace-separated word is one token
            Ok(text.split_whitespace().count())
        }
    }
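
    // Pins the documented normalization behavior: surrounding whitespace is
    // trimmed and internal runs of whitespace collapse into single spaces
    #[test]
    fn test_normalize_text() {
        let middleware = EmbeddingMiddleware::new(MockEmbeddingProvider, true);
        assert_eq!(middleware.normalize_text("  hello \t world \n"), "hello world");
    }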

    #[tokio::test]
    async fn test_embedding_middleware() {
        let provider = MockEmbeddingProvider;
        let middleware = EmbeddingMiddleware::new(provider, true);

        // Test vector normalization
        let embedding = middleware.generate_embedding("test text").await.unwrap();

        // The norm of the normalized vector should be close to 1.0
        let norm = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
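
    // The following tests exercise the tokenizer-backed helpers. They load the
    // bundled cl100k_base tokenizer; token counts for the sample strings are
    // assumed to be roughly one token per word, so the assertions only check
    // structural properties (prefixes, chunk counts, averages), not exact
    // token boundaries.
    #[test]
    fn test_truncate_text() {
        EmbeddingMiddleware::<MockEmbeddingProvider>::initialize_tokenizer().unwrap();

        let config = EmbeddingModelConfig {
            max_context_length: 3,
            ..Default::default()
        };
        let handler =
            LongTextHandler::new(MockEmbeddingProvider, config, ChunkingStrategy::NoChunking);

        // Short input is returned unchanged
        assert_eq!(handler.truncate_text("one two").unwrap(), "one two");

        // Five words exceed the three-token limit, so the text is cut at a
        // token boundary and stays a prefix of the original
        let text = "one two three four five";
        let truncated = handler.truncate_text(text).unwrap();
        assert!(text.starts_with(&truncated));
        assert!(truncated.len() < text.len());
    }

    #[test]
    fn test_chunk_text() {
        EmbeddingMiddleware::<MockEmbeddingProvider>::initialize_tokenizer().unwrap();

        let handler = LongTextHandler::new(
            MockEmbeddingProvider,
            EmbeddingModelConfig::default(),
            ChunkingStrategy::NoChunking,
        );

        let text = "one two three four five six seven eight nine ten";
        let chunks = handler.chunk_text(text, 4, 1).unwrap();

        // The text encodes to more than four tokens, so multiple chunks are
        // produced, and the first chunk is a byte prefix of the original
        assert!(chunks.len() > 1);
        assert!(text.starts_with(chunks[0].as_str()));

        // An overlap that is not smaller than the chunk size is rejected
        assert!(handler.chunk_text(text, 2, 2).is_err());
    }

    #[tokio::test]
    async fn test_chunk_and_average() {
        EmbeddingMiddleware::<MockEmbeddingProvider>::initialize_tokenizer().unwrap();

        let handler = LongTextHandler::new(
            MockEmbeddingProvider,
            EmbeddingModelConfig::default(),
            ChunkingStrategy::ChunkAndAverage {
                chunk_size: 4,
                chunk_overlap: 1,
            },
        );

        // The mock returns the same vector for every chunk, so the average of
        // the chunk embeddings must equal that vector
        let embedding = handler
            .process_embeddings("one two three four five six seven eight")
            .await
            .unwrap();
        assert_eq!(embedding.len(), 4);
        for (value, expected) in embedding.iter().zip([0.1_f32, 0.2, 0.3, 0.4]) {
            assert!((value - expected).abs() < 1e-6);
        }
    }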
}