use crate::backend::{BackendKind, EmbeddingBackend};
use crate::batch::{normalize_embeddings, BatchProcessor};
use crate::error::{InferenceError, Result};
use crate::models::ModelConfig;
use async_trait::async_trait;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{info, instrument};
/// Text-embedding backend that runs a BERT model through Candle.
///
/// Constructed with [`CandleBackend::new`]. Inference tokenizes the input, runs the
/// BERT forward pass, mean-pools hidden states over the attention mask and passes the
/// result through `normalize_embeddings`.
pub struct CandleBackend {
    /// Loaded BERT weights; `Arc`-shared so blocking inference tasks can own a handle.
    model: Arc<BertModel>,
    /// Tokenization/batching helper built from the model's `tokenizer.json`.
    processor: Arc<BatchProcessor>,
    /// Device the model was loaded on (CPU, or CUDA/Metal when requested and available).
    device: Device,
    /// Embedding width, taken from the model config's `hidden_size`.
    dimension: usize,
}
impl CandleBackend {
#[instrument(skip_all, fields(model = %config.model))]
pub async fn new(config: &ModelConfig) -> Result<Self> {
let config = config.clone();
info!("Initialising CandleBackend: model={}", config.model);
let device = Self::resolve_device(config.use_gpu)?;
info!("Candle device: {:?}", device);
let model_id = config.model.model_id();
let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model_id)?;
let files = ["tokenizer.json", "config.json", "model.safetensors"];
for filename in &files {
if !cache_dir.join(filename).exists() {
let model_id_owned = model_id.to_string();
let cache = cache_dir.clone();
let f = filename.to_string();
tokio::task::spawn_blocking(move || {
crate::backend::onnx::OnnxBackend::download_hf_file(&model_id_owned, &f, &cache)
.map_err(InferenceError::HubError)
})
.await
.map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
}
}
let tokenizer_path = cache_dir.join("tokenizer.json");
let config_path = cache_dir.join("config.json");
let model_path = cache_dir.join("model.safetensors");
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let bert_config: BertConfig = {
let f = std::fs::File::open(&config_path)?;
serde_json::from_reader(f)
.map_err(|e| InferenceError::ModelLoadError(format!("config.json parse: {e}")))?
};
let dtype = match &device {
Device::Cpu => DType::F32,
_ => DType::BF16,
};
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_path], dtype, &device)
.map_err(|e| InferenceError::ModelLoadError(format!("VarBuilder: {e}")))?
};
let model = BertModel::load(vb, &bert_config)
.map_err(|e| InferenceError::ModelLoadError(format!("BertModel::load: {e}")))?;
let dimension = bert_config.hidden_size;
let processor = Arc::new(BatchProcessor::new(
tokenizer,
config.model,
config.max_batch_size,
));
info!("CandleBackend ready: dimension={}", dimension);
Ok(Self {
model: Arc::new(model),
processor,
device,
dimension,
})
}
pub fn resolve_device(use_gpu: bool) -> Result<Device> {
let use_gpu = std::env::var("DAKERA_USE_GPU")
.map(|v| v == "1")
.unwrap_or(use_gpu);
if use_gpu {
#[cfg(feature = "cuda")]
{
match Device::new_cuda(0) {
Ok(d) => return Ok(d),
Err(e) => {
tracing::warn!("CUDA device unavailable ({}), falling back to CPU", e);
}
}
}
#[cfg(feature = "metal")]
{
match Device::new_metal(0) {
Ok(d) => return Ok(d),
Err(e) => {
tracing::warn!("Metal device unavailable ({}), falling back to CPU", e);
}
}
}
}
Ok(Device::Cpu)
}
fn embed_batch_sync(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let prepared = self.processor.tokenize_batch(texts)?;
let batch_size = prepared.batch_size;
let seq_len = prepared.seq_len;
let to_tensor = |data: Vec<i64>| -> candle_core::Result<Tensor> {
Tensor::from_vec(data, (batch_size, seq_len), &self.device)
};
let input_ids = to_tensor(prepared.input_ids)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let attention_mask = to_tensor(prepared.attention_mask.clone())
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let token_type_ids = to_tensor(prepared.token_type_ids)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let hidden = self
.model
.forward(&input_ids, &token_type_ids, Some(&attention_mask))
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_f32 = attention_mask
.to_dtype(DType::F32)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_sum = mask_f32
.sum(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let hidden_f32 = hidden
.to_dtype(DType::F32)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_expanded = mask_f32
.unsqueeze(2)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let masked = (hidden_f32 * mask_expanded)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let summed = masked
.sum(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_sum_expanded = mask_sum
.unsqueeze(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let pooled = (summed / mask_sum_expanded)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let flat: Vec<f32> = pooled
.to_vec2::<f32>()
.map_err(|e| InferenceError::InferenceError(e.to_string()))?
.into_iter()
.flatten()
.collect();
let hidden_size = flat.len() / batch_size;
let mut embeddings: Vec<Vec<f32>> = flat.chunks(hidden_size).map(|c| c.to_vec()).collect();
normalize_embeddings(&mut embeddings);
Ok(embeddings)
}
}
#[async_trait]
impl EmbeddingBackend for CandleBackend {
    /// Embeds `texts` without blocking the async runtime.
    ///
    /// Empty input returns an empty result immediately. On non-CPU devices a permit from
    /// `crate::GPU_INFERENCE_SEMAPHORE` is acquired first. The blocking inference task
    /// holds that permit until it finishes, so concurrent GPU work is limited by the
    /// semaphore.
    ///
    /// # Errors
    /// `InferenceError::InferenceError` if the semaphore is closed or the blocking task
    /// panics; otherwise propagates any error from the inference pipeline.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            // Nothing to compute: skip the GPU permit and the blocking task entirely.
            return Ok(vec![]);
        }
        let backend = CandleBackendSync {
            model: Arc::clone(&self.model),
            processor: Arc::clone(&self.processor),
            device: self.device.clone(),
            dimension: self.dimension,
        };
        let texts_owned = texts.to_vec();

        let gpu_permit = if matches!(self.device, Device::Cpu) {
            None
        } else {
            let permit = Arc::clone(&crate::GPU_INFERENCE_SEMAPHORE)
                .acquire_owned()
                .await
                .map_err(|_| {
                    InferenceError::InferenceError(
                        "GPU inference semaphore unexpectedly closed".to_string(),
                    )
                })?;
            Some(permit)
        };

        tokio::task::spawn_blocking(move || {
            // The permit is moved into the task so it is released only when inference
            // ends, even if the caller's future is dropped while the task still runs.
            let _gpu_permit = gpu_permit;
            backend.embed(texts_owned)
        })
        .await
        .map_err(|e| InferenceError::InferenceError(format!("Candle task panicked: {e}")))?
    }

    /// Width of the embeddings this backend produces (the model's `hidden_size`).
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Identifies this backend as [`BackendKind::Candle`].
    fn backend_kind(&self) -> BackendKind {
        BackendKind::Candle
    }
}
/// Owned bundle of a [`CandleBackend`]'s inference state that can be moved into
/// `tokio::task::spawn_blocking` and run the embedding pipeline synchronously.
struct CandleBackendSync {
    /// Shared BERT weights (an `Arc` clone of the backend's model).
    model: Arc<BertModel>,
    /// Shared tokenization/batching helper.
    processor: Arc<BatchProcessor>,
    /// Device the input tensors are created on.
    device: Device,
    /// Embedding width the backend advertises.
    dimension: usize,
}
impl CandleBackendSync {
fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let prepared = self.processor.tokenize_batch(&texts)?;
let batch_size = prepared.batch_size;
let seq_len = prepared.seq_len;
let to_tensor = |data: Vec<i64>| -> candle_core::Result<Tensor> {
Tensor::from_vec(data, (batch_size, seq_len), &self.device)
};
let input_ids = to_tensor(prepared.input_ids)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let attention_mask = to_tensor(prepared.attention_mask.clone())
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let token_type_ids = to_tensor(prepared.token_type_ids)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let hidden = self
.model
.forward(&input_ids, &token_type_ids, Some(&attention_mask))
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_f32 = attention_mask
.to_dtype(DType::F32)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_sum = mask_f32
.sum(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let hidden_f32 = hidden
.to_dtype(DType::F32)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_expanded = mask_f32
.unsqueeze(2)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let masked = (hidden_f32 * mask_expanded)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let summed = masked
.sum(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_sum_expanded = mask_sum
.unsqueeze(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let pooled = (summed / mask_sum_expanded)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let data = pooled
.to_vec2::<f32>()
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let hidden_size = self.dimension;
let mut embeddings: Vec<Vec<f32>> = data
.into_iter()
.map(|row| {
if row.len() == hidden_size {
row
} else {
row.into_iter().take(hidden_size).collect()
}
})
.collect();
normalize_embeddings(&mut embeddings);
Ok(embeddings)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_candle_backend_kind() {
        assert_eq!(BackendKind::Candle.to_string(), "candle");
    }

    #[test]
    fn test_resolve_device_cpu_without_gpu_flag() {
        // Clear the override so only the `use_gpu = false` argument decides.
        std::env::remove_var("DAKERA_USE_GPU");
        let device = CandleBackend::resolve_device(false).unwrap();
        // `matches!` alone only evaluates to a bool; it must be asserted to test anything.
        assert!(matches!(device, Device::Cpu), "expected CPU device, got {device:?}");
    }

    #[test]
    fn test_resolve_device_cpu_fallback_when_no_gpu_hardware() {
        // Requesting a GPU must never fail: missing hardware falls back to CPU.
        let device = CandleBackend::resolve_device(true);
        assert!(device.is_ok());
    }

    /// Smoke check that the `flash-attn` cfg gate evaluates in this build. It asserts
    /// nothing about whether the feature is enabled.
    #[test]
    fn test_flash_attn_feature_gate() {
        let flash_enabled = cfg!(feature = "flash-attn");
        let _ = flash_enabled;
    }
}