use anyhow::Result;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
use super::backend::EmbeddingBackend;
#[cfg(feature = "bert")]
use super::backends::BertBackend;
#[cfg(feature = "model2vec")]
use super::backends::Model2VecBackend;
use super::backends::StaticHashBackend;
use super::concurrency::ConcurrencyController;
use super::config::{EmbeddingConfig, EmbeddingModelType};
use super::pool::MemoryPool;
/// Embedding engine that wraps a single [`EmbeddingBackend`].
///
/// Non-BERT backends are called directly. BERT-based backends are additionally
/// routed through a concurrency limit, per-chunk timeouts and adaptive batch sizing.
pub struct LocalEmbeddingEngine {
/// Backend that produces the embeddings; selected in `new` from `config.model_type`.
backend: Arc<dyn EmbeddingBackend>,
/// Configuration the engine was built with (model type, pool size, batching,
/// timeout and concurrency settings).
config: EmbeddingConfig,
/// Chunk size used when splitting BERT batches. Starts at `config.max_batch_size`
/// and is scaled up/down (clamped to 8..=256) as chunks are processed.
current_batch_size: AtomicUsize,
/// Most recent throughput metric (`success_rate * batch_size / time_ms`) per chunk size.
batch_performance_cache: Arc<DashMap<usize, f64>>,
/// Limits concurrent controlled (BERT) encode operations; sized from
/// `config.max_concurrent_ops`.
concurrency_controller: Arc<ConcurrencyController>,
}
impl LocalEmbeddingEngine {
    /// Creates an engine for `config.model_type`, loading the matching backend.
    ///
    /// Backend selection:
    /// * Model2Vec model types load a `Model2VecBackend` (needs the `model2vec` feature).
    /// * BERT-based model types load a `BertBackend` (needs the `bert` feature).
    /// * Every other model type uses a `StaticHashBackend` over a `MemoryPool` sized by
    ///   `config.memory_pool_size`.
    ///
    /// Setup of the engine's controls:
    /// * The adaptive batch size starts at `config.max_batch_size`.
    /// * The concurrency controller is sized by `config.max_concurrent_ops`.
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// * the selected backend fails to load;
    /// * the model type requires a cargo feature that was disabled at compile time.
    pub async fn new(config: EmbeddingConfig) -> Result<Self> {
        info!(
            "Initializing embedding engine with model: {:?}",
            config.model_type
        );
        let dimension = config.model_type.embedding_dimension();
        let backend: Arc<dyn EmbeddingBackend> = if config.model_type.is_model2vec() {
            #[cfg(feature = "model2vec")]
            {
                Arc::new(Model2VecBackend::load(config.model_type).await?)
            }
            #[cfg(not(feature = "model2vec"))]
            {
                return Err(anyhow::anyhow!(
                    "Model type {:?} requires the `model2vec` feature, which is disabled. \
                     Rebuild post-cortex-embeddings with `--features model2vec` or pick a \
                     different EmbeddingModelType.",
                    config.model_type
                ));
            }
        } else if config.model_type.is_bert_based() {
            #[cfg(feature = "bert")]
            {
                Arc::new(BertBackend::load(config.model_type).await?)
            }
            #[cfg(not(feature = "bert"))]
            {
                return Err(anyhow::anyhow!(
                    "Model type {:?} requires the `bert` feature, which is disabled. \
                     Rebuild post-cortex-embeddings with `--features bert` or pick a \
                     different EmbeddingModelType.",
                    config.model_type
                ));
            }
        } else {
            let pool = Arc::new(MemoryPool::new(config.memory_pool_size, dimension));
            Arc::new(StaticHashBackend::new(dimension, pool))
        };
        let concurrency_controller =
            Arc::new(ConcurrencyController::new(config.max_concurrent_ops));
        Ok(Self {
            backend,
            current_batch_size: AtomicUsize::new(config.max_batch_size),
            batch_performance_cache: Arc::new(DashMap::new()),
            concurrency_controller,
            config,
        })
    }

    /// Returns the chunk size currently used to split BERT batches.
    pub fn current_batch_size(&self) -> usize {
        self.current_batch_size.load(Ordering::Relaxed)
    }

    /// Returns the embedding dimension reported by the active backend.
    pub fn embedding_dimension(&self) -> usize {
        self.backend.embedding_dimension()
    }

    /// Returns whether the active backend reports itself as BERT-based.
    ///
    /// Only BERT-based backends go through the concurrency limit, chunking and
    /// per-chunk timeouts in `encode_batch`.
    pub fn is_bert_based(&self) -> bool {
        self.backend.is_bert_based()
    }

    /// Encodes a single text and returns its embedding.
    ///
    /// # Errors
    ///
    /// Propagates any `encode_batch` error. Also errors if the backend returned
    /// no embedding for the text.
    pub async fn encode_text(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.encode_batch(vec![text.to_string()]).await?;
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No embeddings generated"))
    }

    /// Encodes `texts`, returning one entry per embedding produced by the backend.
    ///
    /// How the input is handled:
    /// * An empty input returns an empty vector without touching the backend.
    /// * Non-BERT backends receive the whole input in a single call. They bypass the
    ///   concurrency limit, chunking and timeout.
    /// * BERT backends are processed chunk-by-chunk under those controls.
    ///
    /// # Errors
    ///
    /// Propagates backend errors. For BERT backends, it also returns an error on
    /// permit acquisition failure or when a chunk times out.
    pub async fn encode_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        if !self.backend.is_bert_based() {
            if matches!(
                self.config.model_type,
                EmbeddingModelType::StaticSimilarityMRL
            ) {
                warn!(
                    "Using StaticHashBackend for model_type {:?} — semantic search will NOT \
                     work correctly! Pick PotionMultilingual (default) or a BERT variant.",
                    self.config.model_type
                );
            }
            return self.backend.process_batch(texts).await;
        }
        // Capture the count up front so `texts` can be moved (not cloned) below.
        let text_count = texts.len();
        info!(
            "Using BERT embeddings for model_type: {:?}, encoding {} texts",
            self.config.model_type, text_count
        );
        let total_start_time = std::time::Instant::now();
        let result = self.encode_batch_with_controls(texts).await;
        debug!(
            "Encoded {} texts in {:?}",
            text_count,
            total_start_time.elapsed()
        );
        result
    }

    /// Encodes `texts` in chunks under a concurrency permit and per-chunk timeout.
    ///
    /// Behavior:
    /// * Each chunk is bounded by `config.operation_timeout_secs`.
    /// * Every processed chunk (successful, failed or timed out) updates the adaptive
    ///   batch size.
    /// * The first failed or timed-out chunk aborts the call, and embeddings from
    ///   earlier chunks are discarded.
    async fn encode_batch_with_controls(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        // Bound to `_permit` (not `_`) so it lives until this function returns;
        // presumably the controller releases the slot when the permit is dropped.
        let _permit = match self.concurrency_controller.try_acquire() {
            Some(permit) => permit,
            None => self.concurrency_controller.acquire().await?,
        };
        let requested_size = if self.config.adaptive_batching {
            self.get_adaptive_batch_size(texts.len()).await
        } else {
            self.current_batch_size()
        };
        // The initial size is the unclamped `config.max_batch_size`, and adaptive scaling
        // can round a size of 1 down to 0. A zero chunk size would never make progress
        // (and `slice::chunks(0)` panics), so always take at least one text per chunk.
        let batch_size = requested_size.max(1);
        let timeout_duration = Duration::from_secs(self.config.operation_timeout_secs);

        let mut all_embeddings = Vec::with_capacity(texts.len());
        // Move strings into each chunk instead of deep-copying them per chunk.
        let mut remaining = texts.into_iter();
        loop {
            let chunk: Vec<String> = remaining.by_ref().take(batch_size).collect();
            if chunk.is_empty() {
                break;
            }
            let chunk_len = chunk.len();
            let start_time = std::time::Instant::now();
            let batch_result = timeout(timeout_duration, self.backend.process_batch(chunk)).await;
            // Sub-millisecond precision: truncating to whole milliseconds could yield 0,
            // which turns the throughput metric into inf/NaN.
            let elapsed_ms = start_time.elapsed().as_secs_f64() * 1000.0;
            match batch_result {
                Ok(Ok(batch_embeddings)) => {
                    all_embeddings.extend(batch_embeddings);
                    self.update_batch_performance(chunk_len, elapsed_ms, 1.0);
                }
                Ok(Err(e)) => {
                    error!("Batch processing failed: {}", e);
                    self.update_batch_performance(chunk_len, elapsed_ms, 0.0);
                    return Err(e);
                }
                Err(_) => {
                    error!("Batch processing timed out");
                    // A timeout is the strongest hint that the chunk was too large;
                    // record it as a failure so the adaptive size shrinks.
                    self.update_batch_performance(chunk_len, elapsed_ms, 0.0);
                    return Err(anyhow::anyhow!(
                        "Batch processing timed out after {} seconds",
                        self.config.operation_timeout_secs
                    ));
                }
            }
        }
        Ok(all_embeddings)
    }

    /// Picks a chunk size for `text_count` texts.
    ///
    /// If all texts fit in the current batch size, that count is returned so everything
    /// goes in one chunk. Otherwise the current size is adjusted by the cached metric
    /// average:
    /// * scaled by 1.2 when the average is > 0.9;
    /// * scaled by 0.8 when the average is < 0.7;
    /// * left unchanged otherwise.
    ///
    /// When nothing is cached, the average defaults to 0.8.
    ///
    /// The result may be 0 for very small current sizes; the caller clamps it.
    async fn get_adaptive_batch_size(&self, text_count: usize) -> usize {
        let base_size = self.current_batch_size();
        if text_count <= base_size {
            return text_count;
        }
        // NOTE(review): DashMap iteration order is unspecified and the cache is keyed by
        // batch size, so these are up to 10 arbitrary entries, not the most recent ones.
        // The cached metric is throughput (texts per ms), not a 0..1 rate, so the
        // 0.7 / 0.9 thresholds below should be confirmed against real workloads.
        let sampled_metrics: Vec<f64> = self
            .batch_performance_cache
            .iter()
            .take(10)
            .map(|entry| *entry.value())
            .collect();
        let avg_performance = if sampled_metrics.is_empty() {
            0.8
        } else {
            sampled_metrics.iter().sum::<f64>() / sampled_metrics.len() as f64
        };
        if avg_performance > 0.9 {
            (base_size as f64 * 1.2) as usize
        } else if avg_performance < 0.7 {
            (base_size as f64 * 0.8) as usize
        } else {
            base_size
        }
    }

    /// Records the outcome of one processed chunk and nudges the shared batch size.
    ///
    /// Metric caching:
    /// * The metric `success_rate * batch_size / time_ms` is cached under `batch_size`.
    /// * Non-finite values (e.g. a zero duration) are skipped so they cannot poison
    ///   the average.
    ///
    /// Batch size adjustment:
    /// * grows by 10% when `success_rate > 0.9` and `time_ms < 1000`;
    /// * shrinks by 10% when `success_rate < 0.7` or `time_ms > 2000`;
    /// * otherwise the size is left untouched.
    ///
    /// The new size is always clamped to `8..=256`.
    fn update_batch_performance(&self, batch_size: usize, time_ms: f64, success_rate: f64) {
        let metric = success_rate / (time_ms / batch_size as f64);
        if metric.is_finite() {
            self.batch_performance_cache.insert(batch_size, metric);
        }
        let factor = if success_rate > 0.9 && time_ms < 1000.0 {
            1.1
        } else if success_rate < 0.7 || time_ms > 2000.0 {
            0.9
        } else {
            return;
        };
        // `fetch_update` retries the closure under contention, replacing a hand-rolled
        // compare-exchange spin loop. The closure always returns `Some`, so the call
        // cannot fail and the result is intentionally ignored.
        let _ = self
            .current_batch_size
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(((current as f64 * factor) as usize).clamp(8, 256))
            });
    }

    /// Returns the concurrency controller's current load.
    pub fn current_concurrency_load(&self) -> usize {
        self.concurrency_controller.current_load()
    }

    /// Returns `(current_load, max_capacity)` from the concurrency controller.
    pub fn get_concurrency_stats(&self) -> (usize, usize) {
        (
            self.concurrency_controller.current_load(),
            self.concurrency_controller.max_capacity(),
        )
    }
}