1use std::sync::OnceLock;
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use tiktoken_rs::{CoreBPE, cl100k_base};
6
7use crate::error::{LLMBrainError, Result};
8
/// Process-wide cl100k_base tokenizer, lazily populated by
/// `EmbeddingMiddleware::initialize_tokenizer` and read by the chunking /
/// truncation helpers below.
static BPE: OnceLock<CoreBPE> = OnceLock::new();
11
/// Backend-agnostic interface for turning text into embedding vectors.
///
/// Implementors must be thread-safe (`Send + Sync`) so they can be shared
/// across async tasks.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generates an embedding vector for a single text.
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;

    /// Generates embeddings for a batch of texts, one vector per input.
    async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>>;

    /// Counts the tokens this backend would consume for `text`.
    fn count_tokens(&self, text: &str) -> Result<usize>;
}
26
/// Static description of an embedding model's identity and limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingModelConfig {
    /// Provider-side model identifier (e.g. `"text-embedding-3-small"`).
    pub model_name: String,

    /// Dimensionality of the vectors the model produces.
    pub dimensions: usize,

    /// Maximum number of input tokens the model accepts.
    pub max_context_length: usize,
}
39
40impl Default for EmbeddingModelConfig {
41 fn default() -> Self {
42 Self {
43 model_name: "text-embedding-3-small".to_owned(),
44 dimensions: 1536,
45 max_context_length: 8191,
46 }
47 }
48}
49
/// Decorator around an [`EmbeddingProvider`] that trims input text and
/// optionally L2-normalizes the returned vectors.
pub struct EmbeddingMiddleware<P: EmbeddingProvider> {
    /// Wrapped backend that actually generates the embeddings.
    provider: P,
    /// When true, each returned vector is scaled to unit L2 norm.
    normalize_vectors: bool,
}
62
63impl<P: EmbeddingProvider> EmbeddingMiddleware<P> {
64 pub fn new(provider: P, normalize_vectors: bool) -> Self {
66 Self {
67 provider,
68 normalize_vectors,
69 }
70 }
71
72 pub fn initialize_tokenizer() -> Result<()> {
74 BPE.get_or_init(|| cl100k_base().expect("Failed to load cl100k_base tokenizer"));
75 Ok(())
76 }
77
78 pub fn normalize_text(&self, text: &str) -> String {
80 text.trim().to_owned()
82 }
83
84 pub fn normalize_vector(&self, vector: &mut [f32]) {
86 if !self.normalize_vectors {
87 return;
88 }
89
90 let norm = vector.iter().map(|&x| x * x).sum::<f32>().sqrt();
92
93 if norm > 1e-10 {
95 for x in vector.iter_mut() {
96 *x /= norm;
97 }
98 }
99 }
100}
101
102#[async_trait]
103impl<P: EmbeddingProvider + Send + Sync> EmbeddingProvider for EmbeddingMiddleware<P> {
104 async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
105 let normalized_text = self.normalize_text(text);
106 let mut embedding = self.provider.generate_embedding(&normalized_text).await?;
107 self.normalize_vector(&mut embedding);
108 Ok(embedding)
109 }
110
111 async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
112 let normalized_texts = texts
113 .iter()
114 .map(|text| self.normalize_text(text))
115 .collect::<Vec<String>>();
116
117 let mut embeddings = self.provider.generate_embeddings(normalized_texts).await?;
118
119 for embedding in &mut embeddings {
120 self.normalize_vector(embedding);
121 }
122
123 Ok(embeddings)
124 }
125
126 fn count_tokens(&self, text: &str) -> Result<usize> {
127 self.provider.count_tokens(text)
128 }
129}
130
/// Strategy for embedding texts that may exceed the model's context window.
///
/// `Debug` and `Clone` are derived for consistency with
/// `EmbeddingModelConfig` and so the strategy can be logged and reused.
#[derive(Debug, Clone, Default)]
pub enum ChunkingStrategy {
    /// Truncate the text to the model's max context length and embed once.
    #[default]
    NoChunking,

    /// Split into overlapping token chunks, embed each chunk, and average
    /// the resulting vectors.
    ChunkAndAverage {
        /// Tokens per chunk.
        chunk_size: usize,
        /// Tokens shared between consecutive chunks; must be smaller than
        /// `chunk_size` or the window cannot advance.
        chunk_overlap: usize,
    },

    /// Embed only the first N tokens.
    UsePrefix(usize),

    /// Embed only the last N tokens.
    UseSuffix(usize),
}
154
/// Embeds texts that may be longer than the model's context window,
/// applying the configured [`ChunkingStrategy`].
pub struct LongTextHandler<P: EmbeddingProvider> {
    /// Backend used to generate the chunk/prefix/suffix embeddings.
    provider: P,
    /// Model limits consulted for truncation (`max_context_length`).
    model_config: EmbeddingModelConfig,
    /// How over-length input is reduced before embedding.
    chunking_strategy: ChunkingStrategy,
}
161
162impl<P: EmbeddingProvider> LongTextHandler<P> {
163 pub fn new(
165 provider: P, model_config: EmbeddingModelConfig, chunking_strategy: ChunkingStrategy,
166 ) -> Self {
167 Self {
168 provider,
169 model_config,
170 chunking_strategy,
171 }
172 }
173
174 pub fn truncate_text(&self, text: &str) -> Result<String> {
176 let token_count = self.provider.count_tokens(text)?;
177
178 if token_count <= self.model_config.max_context_length {
179 return Ok(text.to_owned());
180 }
181
182 let bpe = BPE.get().expect("BPE Tokenizer not initialized");
185 let tokens = bpe.encode_with_special_tokens(text);
186 let truncated_tokens = tokens[0..self.model_config.max_context_length].to_vec();
187
188 bpe.decode(truncated_tokens)
190 .map_err(|e| LLMBrainError::InputError(format!("Failed to decode tokens: {e}")))
191 }
192
193 pub fn chunk_text(&self, text: &str, chunk_size: usize, overlap: usize) -> Result<Vec<String>> {
195 let bpe = BPE.get().expect("BPE Tokenizer not initialized");
196 let tokens = bpe.encode_with_special_tokens(text);
197
198 if tokens.len() <= chunk_size {
199 return Ok(vec![text.to_owned()]);
200 }
201
202 let mut chunks = Vec::new();
203 let mut start = 0;
204
205 while start < tokens.len() {
206 let end = (start + chunk_size).min(tokens.len());
207 let chunk_tokens = tokens[start..end].to_vec();
208
209 let chunk = bpe
211 .decode(chunk_tokens)
212 .map_err(|e| LLMBrainError::InputError(format!("Failed to decode tokens: {e}")))?;
213
214 chunks.push(chunk);
215
216 if end >= tokens.len() {
217 break;
218 }
219
220 start = end - overlap;
222 }
223
224 Ok(chunks)
225 }
226
227 pub async fn process_embeddings(&self, text: &str) -> Result<Vec<f32>> {
229 match &self.chunking_strategy {
230 ChunkingStrategy::NoChunking => {
231 let truncated = self.truncate_text(text)?;
232 self.provider.generate_embedding(&truncated).await
233 }
234
235 ChunkingStrategy::ChunkAndAverage {
236 chunk_size,
237 chunk_overlap,
238 } => {
239 let chunks = self.chunk_text(text, *chunk_size, *chunk_overlap)?;
240 if chunks.is_empty() {
241 return Err(LLMBrainError::InputError(
242 "No chunks generated from text".to_owned(),
243 ));
244 }
245
246 let embeddings = self.provider.generate_embeddings(chunks).await?;
247
248 if embeddings.is_empty() {
250 return Err(LLMBrainError::ApiError(
251 "No embeddings generated".to_owned(),
252 ));
253 }
254
255 let dimensions = embeddings[0].len();
256 let mut average = vec![0.0; dimensions];
257
258 for embedding in &embeddings {
259 for (i, &value) in embedding.iter().enumerate() {
260 average[i] += value / embeddings.len() as f32;
261 }
262 }
263
264 Ok(average)
265 }
266
267 ChunkingStrategy::UsePrefix(size) => {
268 let bpe = BPE.get().expect("BPE Tokenizer not initialized");
269 let tokens = bpe.encode_with_special_tokens(text);
270
271 let prefix_tokens = tokens.iter().take(*size).cloned().collect::<Vec<_>>();
272
273 let prefix = bpe.decode(prefix_tokens).map_err(|e| {
274 LLMBrainError::InputError(format!("Failed to decode tokens: {e}"))
275 })?;
276
277 self.provider.generate_embedding(&prefix).await
278 }
279
280 ChunkingStrategy::UseSuffix(size) => {
281 let bpe = BPE.get().expect("BPE Tokenizer not initialized");
282 let tokens = bpe.encode_with_special_tokens(text);
283
284 let suffix_tokens = tokens.iter().rev().take(*size).cloned().collect::<Vec<_>>();
285
286 let suffix = bpe.decode(suffix_tokens).map_err(|e| {
287 LLMBrainError::InputError(format!("Failed to decode tokens: {e}"))
288 })?;
289
290 self.provider.generate_embedding(&suffix).await
291 }
292 }
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Stub provider that always emits the same 4-dimensional vector and
    /// counts whitespace-separated words as tokens.
    struct MockEmbeddingProvider;

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn generate_embedding(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3, 0.4])
        }

        async fn generate_embeddings(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
            // One fixed vector per input text.
            Ok(texts.iter().map(|_| vec![0.1, 0.2, 0.3, 0.4]).collect())
        }

        fn count_tokens(&self, text: &str) -> Result<usize> {
            Ok(text.split_whitespace().count())
        }
    }

    #[tokio::test]
    async fn test_embedding_middleware() {
        let middleware = EmbeddingMiddleware::new(MockEmbeddingProvider, true);

        let embedding = middleware.generate_embedding("test text").await.unwrap();

        // With normalization enabled the result must have unit L2 norm.
        let norm = embedding.iter().map(|&x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }
}