// post_cortex_embeddings/embeddings/engine.rs
use anyhow::Result;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use tokio::time::timeout;
use tracing::{debug, error, info, warn};

use super::backend::EmbeddingBackend;
#[cfg(feature = "bert")]
use super::backends::BertBackend;
#[cfg(feature = "model2vec")]
use super::backends::Model2VecBackend;
use super::backends::StaticHashBackend;
use super::concurrency::ConcurrencyController;
use super::config::{EmbeddingConfig, EmbeddingModelType};
use super::pool::MemoryPool;
36
/// Embedding engine wrapping a single [`EmbeddingBackend`] selected from an
/// [`EmbeddingConfig`].
///
/// For BERT-based backends it adds a concurrency limit, per-batch timeouts and
/// runtime-adjusted batch sizing. Other backends are called directly.
pub struct LocalEmbeddingEngine {
    /// Backend that produces the embedding vectors.
    backend: Arc<dyn EmbeddingBackend>,
    /// Configuration the engine was built from (model type, batching, timeouts).
    config: EmbeddingConfig,
    /// Chunk size currently used to split BERT batches; adjusted after each batch.
    current_batch_size: AtomicUsize,
    /// Most recent performance metric recorded for each batch size (later inserts
    /// for the same size overwrite earlier ones).
    batch_performance_cache: Arc<DashMap<usize, f64>>,
    /// Permit source bounding concurrent BERT batch encodings; built from
    /// `config.max_concurrent_ops`.
    concurrency_controller: Arc<ConcurrencyController>,
}
48
49impl LocalEmbeddingEngine {
50 pub async fn new(config: EmbeddingConfig) -> Result<Self> {
52 info!(
53 "Initializing embedding engine with model: {:?}",
54 config.model_type
55 );
56
57 let dimension = config.model_type.embedding_dimension();
58 let backend: Arc<dyn EmbeddingBackend> = if config.model_type.is_model2vec() {
59 #[cfg(feature = "model2vec")]
60 {
61 Arc::new(Model2VecBackend::load(config.model_type).await?)
62 }
63 #[cfg(not(feature = "model2vec"))]
64 {
65 return Err(anyhow::anyhow!(
66 "Model type {:?} requires the `model2vec` feature, which is disabled. \
67 Rebuild post-cortex-embeddings with `--features model2vec` or pick a \
68 different EmbeddingModelType.",
69 config.model_type
70 ));
71 }
72 } else if config.model_type.is_bert_based() {
73 #[cfg(feature = "bert")]
74 {
75 Arc::new(BertBackend::load(config.model_type).await?)
76 }
77 #[cfg(not(feature = "bert"))]
78 {
79 return Err(anyhow::anyhow!(
80 "Model type {:?} requires the `bert` feature, which is disabled. \
81 Rebuild post-cortex-embeddings with `--features bert` or pick a \
82 different EmbeddingModelType.",
83 config.model_type
84 ));
85 }
86 } else {
87 let pool = Arc::new(MemoryPool::new(config.memory_pool_size, dimension));
88 Arc::new(StaticHashBackend::new(dimension, pool))
89 };
90
91 let concurrency_controller =
92 Arc::new(ConcurrencyController::new(config.max_concurrent_ops));
93
94 Ok(Self {
95 backend,
96 current_batch_size: AtomicUsize::new(config.max_batch_size),
97 batch_performance_cache: Arc::new(DashMap::new()),
98 concurrency_controller,
99 config,
100 })
101 }
102
103 pub fn current_batch_size(&self) -> usize {
105 self.current_batch_size.load(Ordering::Relaxed)
106 }
107
108 pub fn embedding_dimension(&self) -> usize {
110 self.backend.embedding_dimension()
111 }
112
113 pub fn is_bert_based(&self) -> bool {
115 self.backend.is_bert_based()
116 }
117
118 pub async fn encode_text(&self, text: &str) -> Result<Vec<f32>> {
120 let embeddings = self.encode_batch(vec![text.to_string()]).await?;
121 embeddings
122 .into_iter()
123 .next()
124 .ok_or_else(|| anyhow::anyhow!("No embeddings generated"))
125 }
126
127 pub async fn encode_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
129 if texts.is_empty() {
130 return Ok(Vec::new());
131 }
132
133 if !self.backend.is_bert_based() {
140 if matches!(
141 self.config.model_type,
142 EmbeddingModelType::StaticSimilarityMRL
143 ) {
144 warn!(
145 "Using StaticHashBackend for model_type {:?} — semantic search will NOT \
146 work correctly! Pick PotionMultilingual (default) or a BERT variant.",
147 self.config.model_type
148 );
149 }
150 return self.backend.process_batch(texts).await;
151 }
152
153 info!(
154 "Using BERT embeddings for model_type: {:?}, encoding {} texts",
155 self.config.model_type,
156 texts.len()
157 );
158
159 let total_start_time = std::time::Instant::now();
160 let result = self.encode_batch_with_controls(texts.clone()).await;
161
162 let total_time = total_start_time.elapsed();
163 debug!("Encoded {} texts in {:?}", texts.len(), total_time);
164
165 result
166 }
167
168 async fn encode_batch_with_controls(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
170 let _permit = match self.concurrency_controller.try_acquire() {
171 Some(permit) => permit,
172 None => self.concurrency_controller.acquire().await?,
173 };
174
175 let batch_size = if self.config.adaptive_batching {
176 self.get_adaptive_batch_size(texts.len()).await
177 } else {
178 self.current_batch_size()
179 };
180
181 let mut all_embeddings = Vec::new();
182
183 for chunk in texts.chunks(batch_size) {
184 let start_time = std::time::Instant::now();
185
186 let batch_result = timeout(
187 Duration::from_secs(self.config.operation_timeout_secs),
188 self.backend.process_batch(chunk.to_vec()),
189 )
190 .await;
191
192 match batch_result {
193 Ok(Ok(batch_embeddings)) => {
194 all_embeddings.extend(batch_embeddings);
195 let time_ms = start_time.elapsed().as_millis() as f64;
196 self.update_batch_performance(chunk.len(), time_ms, 1.0);
197 }
198 Ok(Err(e)) => {
199 error!("Batch processing failed: {}", e);
200 self.update_batch_performance(
201 chunk.len(),
202 start_time.elapsed().as_millis() as f64,
203 0.0,
204 );
205 return Err(e);
206 }
207 Err(_) => {
208 error!("Batch processing timed out");
209 return Err(anyhow::anyhow!(
210 "Batch processing timed out after {} seconds",
211 self.config.operation_timeout_secs
212 ));
213 }
214 }
215 }
216
217 Ok(all_embeddings)
218 }
219
220 async fn get_adaptive_batch_size(&self, text_count: usize) -> usize {
222 let base_size = self.current_batch_size();
223
224 if text_count <= base_size {
225 return text_count;
226 }
227
228 let recent_performance: Vec<f64> = self
229 .batch_performance_cache
230 .iter()
231 .take(10)
232 .map(|entry| *entry.value())
233 .collect();
234
235 let avg_performance = if recent_performance.is_empty() {
236 0.8 } else {
238 recent_performance.iter().sum::<f64>() / recent_performance.len() as f64
239 };
240
241 if avg_performance > 0.9 {
242 (base_size as f64 * 1.2) as usize
243 } else if avg_performance < 0.7 {
244 (base_size as f64 * 0.8) as usize
245 } else {
246 base_size
247 }
248 }
249
250 fn update_batch_performance(&self, batch_size: usize, time_ms: f64, success_rate: f64) {
252 let metric = success_rate / (time_ms / batch_size as f64);
253 self.batch_performance_cache.insert(batch_size, metric);
254
255 loop {
256 let current = self.current_batch_size.load(Ordering::Acquire);
257
258 let new_size = if success_rate > 0.9 && time_ms < 1000.0 {
259 (current as f64 * 1.1) as usize
260 } else if success_rate < 0.7 || time_ms > 2000.0 {
261 (current as f64 * 0.9) as usize
262 } else {
263 return; };
265
266 let clamped = new_size.clamp(8, 256);
267
268 match self.current_batch_size.compare_exchange_weak(
269 current,
270 clamped,
271 Ordering::AcqRel,
272 Ordering::Relaxed,
273 ) {
274 Ok(_) => return,
275 Err(_) => {
276 std::hint::spin_loop();
277 continue;
278 }
279 }
280 }
281 }
282
283 pub fn current_concurrency_load(&self) -> usize {
285 self.concurrency_controller.current_load()
286 }
287
288 pub fn get_concurrency_stats(&self) -> (usize, usize) {
290 (
291 self.concurrency_controller.current_load(),
292 self.concurrency_controller.max_capacity(),
293 )
294 }
295}