Skip to main content

post_cortex_embeddings/embeddings/
engine.rs

1// Copyright (c) 2025 Julius ML
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12
13//! `LocalEmbeddingEngine` — the public entry point.
14//!
15//! Selects a backend based on [`EmbeddingConfig::model_type`] and wraps the BERT
16//! path with concurrency control, timeouts, and adaptive batch sizing. The static
17//! path is dispatched directly.
18
19use anyhow::Result;
20use dashmap::DashMap;
21use std::sync::Arc;
22use std::sync::atomic::{AtomicUsize, Ordering};
23use std::time::Duration;
24use tokio::time::timeout;
25use tracing::{debug, error, info, warn};
26
27use super::backend::EmbeddingBackend;
28#[cfg(feature = "bert")]
29use super::backends::BertBackend;
30#[cfg(feature = "model2vec")]
31use super::backends::Model2VecBackend;
32use super::backends::StaticHashBackend;
33use super::concurrency::ConcurrencyController;
34use super::config::{EmbeddingConfig, EmbeddingModelType};
35use super::pool::MemoryPool;
36
/// Local embedding engine.
///
/// Holds the backend chosen at construction time together with the state the
/// BERT path uses for throttling: the adaptive batch size, the per-batch-size
/// metric cache, and the concurrency gate.
pub struct LocalEmbeddingEngine {
    /// Backend selected in `new` from `config.model_type`; every encode call
    /// ends in `backend.process_batch`.
    backend: Arc<dyn EmbeddingBackend>,
    /// Configuration the engine was built with. Read at runtime for the model
    /// type, the adaptive-batching flag, and the per-batch timeout.
    config: EmbeddingConfig,
    /// Current batch size (atomic — adapts based on runtime performance).
    current_batch_size: AtomicUsize,
    /// Performance metric per observed batch size — drives adaptive batching.
    batch_performance_cache: Arc<DashMap<usize, f64>>,
    /// Concurrency gate for the BERT path.
    concurrency_controller: Arc<ConcurrencyController>,
}
48
49impl LocalEmbeddingEngine {
50    /// Create a new embedding engine with the given configuration.
51    pub async fn new(config: EmbeddingConfig) -> Result<Self> {
52        info!(
53            "Initializing embedding engine with model: {:?}",
54            config.model_type
55        );
56
57        let dimension = config.model_type.embedding_dimension();
58        let backend: Arc<dyn EmbeddingBackend> = if config.model_type.is_model2vec() {
59            #[cfg(feature = "model2vec")]
60            {
61                Arc::new(Model2VecBackend::load(config.model_type).await?)
62            }
63            #[cfg(not(feature = "model2vec"))]
64            {
65                return Err(anyhow::anyhow!(
66                    "Model type {:?} requires the `model2vec` feature, which is disabled. \
67                     Rebuild post-cortex-embeddings with `--features model2vec` or pick a \
68                     different EmbeddingModelType.",
69                    config.model_type
70                ));
71            }
72        } else if config.model_type.is_bert_based() {
73            #[cfg(feature = "bert")]
74            {
75                Arc::new(BertBackend::load(config.model_type).await?)
76            }
77            #[cfg(not(feature = "bert"))]
78            {
79                return Err(anyhow::anyhow!(
80                    "Model type {:?} requires the `bert` feature, which is disabled. \
81                     Rebuild post-cortex-embeddings with `--features bert` or pick a \
82                     different EmbeddingModelType.",
83                    config.model_type
84                ));
85            }
86        } else {
87            let pool = Arc::new(MemoryPool::new(config.memory_pool_size, dimension));
88            Arc::new(StaticHashBackend::new(dimension, pool))
89        };
90
91        let concurrency_controller =
92            Arc::new(ConcurrencyController::new(config.max_concurrent_ops));
93
94        Ok(Self {
95            backend,
96            current_batch_size: AtomicUsize::new(config.max_batch_size),
97            batch_performance_cache: Arc::new(DashMap::new()),
98            concurrency_controller,
99            config,
100        })
101    }
102
103    /// Get current batch size.
104    pub fn current_batch_size(&self) -> usize {
105        self.current_batch_size.load(Ordering::Relaxed)
106    }
107
108    /// Get embedding dimension.
109    pub fn embedding_dimension(&self) -> usize {
110        self.backend.embedding_dimension()
111    }
112
113    /// Check if the active backend is BERT-based.
114    pub fn is_bert_based(&self) -> bool {
115        self.backend.is_bert_based()
116    }
117
118    /// Encode a single text into an embedding.
119    pub async fn encode_text(&self, text: &str) -> Result<Vec<f32>> {
120        let embeddings = self.encode_batch(vec![text.to_string()]).await?;
121        embeddings
122            .into_iter()
123            .next()
124            .ok_or_else(|| anyhow::anyhow!("No embeddings generated"))
125    }
126
127    /// Encode a batch of texts into embeddings.
128    pub async fn encode_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
129        if texts.is_empty() {
130            return Ok(Vec::new());
131        }
132
133        // Non-BERT path: dispatch directly — no concurrency gating
134        // required because both Model2Vec and the hash fallback are
135        // ms-cheap. The hash fallback is the only case that does *not*
136        // produce real semantic embeddings, so the warning is scoped to
137        // that variant; Model2Vec is a legitimate static-embedding
138        // backend and silently uses the same direct path.
139        if !self.backend.is_bert_based() {
140            if matches!(
141                self.config.model_type,
142                EmbeddingModelType::StaticSimilarityMRL
143            ) {
144                warn!(
145                    "Using StaticHashBackend for model_type {:?} — semantic search will NOT \
146                     work correctly! Pick PotionMultilingual (default) or a BERT variant.",
147                    self.config.model_type
148                );
149            }
150            return self.backend.process_batch(texts).await;
151        }
152
153        info!(
154            "Using BERT embeddings for model_type: {:?}, encoding {} texts",
155            self.config.model_type,
156            texts.len()
157        );
158
159        let total_start_time = std::time::Instant::now();
160        let result = self.encode_batch_with_controls(texts.clone()).await;
161
162        let total_time = total_start_time.elapsed();
163        debug!("Encoded {} texts in {:?}", texts.len(), total_time);
164
165        result
166    }
167
168    /// BERT path with concurrency, timeout, and adaptive batching.
169    async fn encode_batch_with_controls(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
170        let _permit = match self.concurrency_controller.try_acquire() {
171            Some(permit) => permit,
172            None => self.concurrency_controller.acquire().await?,
173        };
174
175        let batch_size = if self.config.adaptive_batching {
176            self.get_adaptive_batch_size(texts.len()).await
177        } else {
178            self.current_batch_size()
179        };
180
181        let mut all_embeddings = Vec::new();
182
183        for chunk in texts.chunks(batch_size) {
184            let start_time = std::time::Instant::now();
185
186            let batch_result = timeout(
187                Duration::from_secs(self.config.operation_timeout_secs),
188                self.backend.process_batch(chunk.to_vec()),
189            )
190            .await;
191
192            match batch_result {
193                Ok(Ok(batch_embeddings)) => {
194                    all_embeddings.extend(batch_embeddings);
195                    let time_ms = start_time.elapsed().as_millis() as f64;
196                    self.update_batch_performance(chunk.len(), time_ms, 1.0);
197                }
198                Ok(Err(e)) => {
199                    error!("Batch processing failed: {}", e);
200                    self.update_batch_performance(
201                        chunk.len(),
202                        start_time.elapsed().as_millis() as f64,
203                        0.0,
204                    );
205                    return Err(e);
206                }
207                Err(_) => {
208                    error!("Batch processing timed out");
209                    return Err(anyhow::anyhow!(
210                        "Batch processing timed out after {} seconds",
211                        self.config.operation_timeout_secs
212                    ));
213                }
214            }
215        }
216
217        Ok(all_embeddings)
218    }
219
220    /// Get adaptive batch size based on recent performance history.
221    async fn get_adaptive_batch_size(&self, text_count: usize) -> usize {
222        let base_size = self.current_batch_size();
223
224        if text_count <= base_size {
225            return text_count;
226        }
227
228        let recent_performance: Vec<f64> = self
229            .batch_performance_cache
230            .iter()
231            .take(10)
232            .map(|entry| *entry.value())
233            .collect();
234
235        let avg_performance = if recent_performance.is_empty() {
236            0.8 // Default success rate
237        } else {
238            recent_performance.iter().sum::<f64>() / recent_performance.len() as f64
239        };
240
241        if avg_performance > 0.9 {
242            (base_size as f64 * 1.2) as usize
243        } else if avg_performance < 0.7 {
244            (base_size as f64 * 0.8) as usize
245        } else {
246            base_size
247        }
248    }
249
250    /// Update batch performance stats (atomic CAS loop on current_batch_size).
251    fn update_batch_performance(&self, batch_size: usize, time_ms: f64, success_rate: f64) {
252        let metric = success_rate / (time_ms / batch_size as f64);
253        self.batch_performance_cache.insert(batch_size, metric);
254
255        loop {
256            let current = self.current_batch_size.load(Ordering::Acquire);
257
258            let new_size = if success_rate > 0.9 && time_ms < 1000.0 {
259                (current as f64 * 1.1) as usize
260            } else if success_rate < 0.7 || time_ms > 2000.0 {
261                (current as f64 * 0.9) as usize
262            } else {
263                return; // No change needed
264            };
265
266            let clamped = new_size.clamp(8, 256);
267
268            match self.current_batch_size.compare_exchange_weak(
269                current,
270                clamped,
271                Ordering::AcqRel,
272                Ordering::Relaxed,
273            ) {
274                Ok(_) => return,
275                Err(_) => {
276                    std::hint::spin_loop();
277                    continue;
278                }
279            }
280        }
281    }
282
283    /// Get current concurrency load.
284    pub fn current_concurrency_load(&self) -> usize {
285        self.concurrency_controller.current_load()
286    }
287
288    /// Get concurrency stats: `(current_load, max_capacity)`.
289    pub fn get_concurrency_stats(&self) -> (usize, usize) {
290        (
291            self.concurrency_controller.current_load(),
292            self.concurrency_controller.max_capacity(),
293        )
294    }
295}