use crate::SymbolId;
use crossbeam_channel::{Receiver, Sender, bounded};
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use super::SemanticSearchError;
use super::remote::{RemoteEmbedder, run_async};
/// Where text embeddings come from for semantic search.
///
/// Either a pool of in-process models ([`EmbeddingPool`]) or a shared handle to
/// a [`RemoteEmbedder`].
pub enum EmbeddingBackend {
    /// Embeddings are computed locally by a pool of model instances.
    Local(EmbeddingPool),
    /// Embeddings are requested from a remote embedder shared via `Arc`.
    Remote(Arc<RemoteEmbedder>),
}
impl EmbeddingBackend {
pub fn dimensions(&self) -> usize {
match self {
EmbeddingBackend::Local(pool) => pool.dimensions(),
EmbeddingBackend::Remote(r) => r.dim(),
}
}
pub fn log_usage_stats(&self) {
if let EmbeddingBackend::Local(pool) = self {
pool.log_usage_stats();
}
}
pub fn model_name(&self) -> &str {
match self {
EmbeddingBackend::Local(pool) => pool.model_name(),
EmbeddingBackend::Remote(_) => "remote",
}
}
pub fn embed_one(&self, text: &str) -> Result<Vec<f32>, SemanticSearchError> {
match self {
EmbeddingBackend::Local(pool) => pool.embed_one(text),
EmbeddingBackend::Remote(r) => {
let r = Arc::clone(r);
let text = text.to_string();
run_async(async move {
let results = r.embed(&[text]).await?;
results.into_iter().next().ok_or_else(|| {
SemanticSearchError::EmbeddingError("Remote embed returned empty".into())
})
})
}
}
}
pub fn embed_parallel(
&self,
items: &[(SymbolId, &str, &str)],
) -> Vec<(SymbolId, Vec<f32>, String)> {
match self {
EmbeddingBackend::Local(pool) => pool.embed_parallel(items),
EmbeddingBackend::Remote(r) => {
let r = Arc::clone(r);
let texts: Vec<String> = items.iter().map(|(_, t, _)| t.to_string()).collect();
let dim = r.dim();
let embeddings = run_async(async move { r.embed(&texts).await });
match embeddings {
Ok(embs) => embs
.into_iter()
.zip(items.iter())
.filter_map(|(emb, (id, _, lang))| {
if emb.len() == dim {
Some((*id, emb, (*lang).to_string()))
} else {
tracing::warn!(
target: "semantic",
"Remote dim mismatch for {}: expected {dim}, got {}",
id.to_u32(), emb.len()
);
None
}
})
.collect(),
Err(e) => {
tracing::error!(target: "semantic", "Remote embed_parallel failed: {e}");
Vec::new()
}
}
}
}
}
}
/// One loaded model, checked in and out of an [`EmbeddingPool`].
struct ModelInstance {
    /// The loaded `fastembed` model.
    model: TextEmbedding,
    /// Position of this instance in the pool; indexes the pool's usage counters.
    id: usize,
}
/// A fixed-size pool of local embedding models shared across threads.
///
/// Idle instances sit in a bounded channel. Callers take one out, use it, and
/// send it back, so at most `pool_size` embeddings run concurrently.
pub struct EmbeddingPool {
    /// Returns instances to the pool.
    model_sender: Sender<ModelInstance>,
    /// Hands out idle instances; blocks while all are in use.
    model_receiver: Receiver<ModelInstance>,
    /// Number of model instances created.
    pool_size: usize,
    /// Length of the vectors produced by the model.
    dimensions: usize,
    /// Display name of the underlying model.
    model_name: String,
    /// Per-instance checkout counts, indexed by `ModelInstance::id`.
    usage_counters: Vec<AtomicUsize>,
}
impl EmbeddingPool {
pub fn new(pool_size: usize, model: EmbeddingModel) -> Result<Self, SemanticSearchError> {
let pool_size = pool_size.max(1);
let (sender, receiver) = bounded(pool_size);
let cache_dir = crate::init::models_dir();
let model_name = crate::vector::model_to_string(&model);
tracing::info!(
target: "semantic",
"Initializing embedding pool: {pool_size} instances ({model_name})"
);
let mut dimensions = 0;
let usage_counters: Vec<AtomicUsize> =
(0..pool_size).map(|_| AtomicUsize::new(0)).collect();
for i in 0..pool_size {
let mut text_model = TextEmbedding::try_new(
InitOptions::new(model.clone())
.with_cache_dir(cache_dir.clone())
.with_show_download_progress(i == 0),
)
.map_err(|e| {
SemanticSearchError::ModelInitError(format!(
"Failed to initialize model instance {}: {}",
i + 1,
e
))
})?;
if i == 0 {
let test_embedding = text_model
.embed(vec!["test"], None)
.map_err(|e| SemanticSearchError::EmbeddingError(e.to_string()))?;
dimensions = test_embedding.into_iter().next().unwrap().len();
}
sender
.send(ModelInstance {
model: text_model,
id: i,
})
.expect("Pool channel should not be closed");
}
tracing::info!(
target: "semantic",
"Embedding pool ready: {pool_size} instances, {dimensions} dimensions"
);
Ok(Self {
model_sender: sender,
model_receiver: receiver,
pool_size,
dimensions,
model_name,
usage_counters,
})
}
pub fn with_size(pool_size: usize) -> Result<Self, SemanticSearchError> {
Self::new(pool_size, EmbeddingModel::AllMiniLML6V2)
}
fn acquire(&self) -> ModelInstance {
let instance = self
.model_receiver
.recv()
.expect("Pool should not be empty");
self.usage_counters[instance.id].fetch_add(1, Ordering::Relaxed);
instance
}
fn release(&self, instance: ModelInstance) {
let _ = self.model_sender.send(instance);
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
pub fn pool_size(&self) -> usize {
self.pool_size
}
pub fn model_name(&self) -> &str {
&self.model_name
}
pub fn embed_one(&self, text: &str) -> Result<Vec<f32>, SemanticSearchError> {
if text.trim().is_empty() {
return Err(SemanticSearchError::EmbeddingError(
"Empty text".to_string(),
));
}
let mut instance = self.acquire();
let result = instance
.model
.embed(vec![text], None)
.map_err(|e| SemanticSearchError::EmbeddingError(e.to_string()));
self.release(instance);
result.map(|mut v| v.remove(0))
}
pub fn log_usage_stats(&self) {
let counts: Vec<usize> = self
.usage_counters
.iter()
.map(|c| c.load(Ordering::Relaxed))
.collect();
let total: usize = counts.iter().sum();
if total > 0 {
let usage_str: Vec<String> = counts
.iter()
.enumerate()
.map(|(i, c)| format!("model[{i}]={c}"))
.collect();
tracing::info!(
target: "semantic",
"Embedding pool usage: {} (total: {total})",
usage_str.join(", ")
);
}
}
pub fn embed_parallel(
&self,
items: &[(SymbolId, &str, &str)],
) -> Vec<(SymbolId, Vec<f32>, String)> {
use rayon::prelude::*;
const BATCH_SIZE: usize = 64;
let valid_items: Vec<_> = items
.iter()
.filter(|(_, doc, _)| !doc.trim().is_empty())
.collect();
if valid_items.is_empty() {
return Vec::new();
}
let results: Vec<_> = valid_items
.chunks(BATCH_SIZE)
.par_bridge()
.flat_map(|batch| {
let texts: Vec<&str> = batch.iter().map(|(_, doc, _)| *doc).collect();
let mut instance = self.acquire();
let embeddings_result = instance.model.embed(texts.clone(), None);
self.release(instance);
match embeddings_result {
Ok(embeddings) => {
let mut results = Vec::with_capacity(batch.len());
for (item, embedding) in batch.iter().zip(embeddings.into_iter()) {
let (symbol_id, _, language) = *item;
if embedding.len() == self.dimensions {
results.push((*symbol_id, embedding, (*language).to_string()));
} else {
tracing::warn!(
target: "semantic",
"Dimension mismatch for {}: expected {}, got {}",
symbol_id.to_u32(),
self.dimensions,
embedding.len()
);
}
}
results
}
Err(e) => {
tracing::warn!(target: "semantic", "Batch embedding failed: {e}");
Vec::new()
}
}
})
.collect();
self.log_usage_stats();
results
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "Downloads 86MB model - run with --ignored"]
    fn test_pool_creation() {
        let pool = EmbeddingPool::with_size(2).unwrap();
        assert_eq!(pool.pool_size(), 2);
        assert_eq!(pool.dimensions(), 384);
    }

    #[test]
    #[ignore = "Downloads 86MB model - run with --ignored"]
    fn test_parallel_embedding() {
        let pool = EmbeddingPool::with_size(2).unwrap();
        let items = vec![
            (SymbolId::new(1).unwrap(), "Parse JSON data", "rust"),
            (SymbolId::new(2).unwrap(), "Connect to database", "rust"),
            // Whitespace-only text must be skipped, not embedded.
            (SymbolId::new(3).unwrap(), "   ", "rust"),
        ];
        let results = pool.embed_parallel(&items);
        assert_eq!(results.len(), 2);
        for (_, embedding, language) in &results {
            assert_eq!(embedding.len(), pool.dimensions());
            assert_eq!(language.as_str(), "rust");
        }
    }

    #[test]
    #[ignore = "Downloads 86MB model - run with --ignored"]
    fn test_embed_one() {
        let pool = EmbeddingPool::with_size(1).unwrap();
        assert!(pool.embed_one("   ").is_err());
        let embedding = pool.embed_one("Parse JSON data").unwrap();
        assert_eq!(embedding.len(), pool.dimensions());
    }
}