use crate::backend::{BackendKind, EmbeddingBackend};
use crate::batch::{normalize_embeddings, BatchProcessor};
use crate::error::{InferenceError, Result};
use crate::models::ModelConfig;
use async_trait::async_trait;
use candle_core::quantized::gguf_file;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{info, instrument};
/// Quantisation variant of the GGUF weight file to load.
///
/// Each variant maps to one file name (see [`QuantLevel::gguf_filename`]) and
/// is normally selected through the `DAKERA_QUANT_LEVEL` environment variable
/// (see [`QuantLevel::from_env`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantLevel {
    /// Selects the `…-q8_0.gguf` file; also the fallback when nothing valid is configured.
    Q8_0,
    /// Selects the `…-q5_k_m.gguf` file (`q5_k_m` / `q5km`).
    Q5KM,
    /// Selects the `…-q4_k_m.gguf` file (`q4_k_m` / `q4km`).
    Q4KM,
}

impl QuantLevel {
    /// Environment variable consulted by [`QuantLevel::from_env`].
    const ENV_VAR: &'static str = "DAKERA_QUANT_LEVEL";

    /// Reads the quantisation level from `DAKERA_QUANT_LEVEL`.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace.
    /// Returns [`QuantLevel::Q8_0`] when the variable is unset, is not valid
    /// Unicode, or holds a value that is not recognised.
    pub fn from_env() -> Self {
        std::env::var(Self::ENV_VAR)
            .ok()
            .and_then(|raw| Self::parse(&raw))
            .unwrap_or(QuantLevel::Q8_0)
    }

    /// Parses a user-supplied level name; `None` means "not recognised".
    ///
    /// Kept separate from [`QuantLevel::from_env`] so the accepted spellings
    /// can be checked without touching process-wide environment state.
    fn parse(raw: &str) -> Option<Self> {
        match raw.trim().to_lowercase().as_str() {
            "q4_k_m" | "q4km" => Some(QuantLevel::Q4KM),
            "q5_k_m" | "q5km" => Some(QuantLevel::Q5KM),
            "q8_0" => Some(QuantLevel::Q8_0),
            _ => None,
        }
    }

    /// File name of the GGUF weights for this quantisation level.
    pub fn gguf_filename(&self) -> &'static str {
        match self {
            QuantLevel::Q8_0 => "bge-large-en-v1.5-q8_0.gguf",
            QuantLevel::Q5KM => "bge-large-en-v1.5-q5_k_m.gguf",
            QuantLevel::Q4KM => "bge-large-en-v1.5-q4_k_m.gguf",
        }
    }
}
/// Embedding backend that runs a BERT model whose weights are loaded from a
/// GGUF file.
pub struct GgufBackend {
/// Loaded model; reference-counted so each inference task can hold its own handle.
model: Arc<BertModel>,
/// Tokenises incoming texts into the id/mask vectors fed to the model.
processor: Arc<BatchProcessor>,
/// Device on which input tensors are created.
device: Device,
/// Length of each returned embedding, taken from the model config's `hidden_size`.
dimension: usize,
/// Quantisation level that was selected when the backend was constructed.
quant: QuantLevel,
}
impl GgufBackend {
#[instrument(skip_all, fields(model = %config.model))]
pub async fn new(config: &ModelConfig) -> Result<Self> {
let config = config.clone();
let quant = QuantLevel::from_env();
info!(
"Initialising GgufBackend: model={}, quant={:?}",
config.model, quant
);
let device = crate::backend::candle_backend::CandleBackend::resolve_device(config.use_gpu)?;
let model_id = config.model.model_id();
let gguf_repo = config.model.gguf_repo_id();
let gguf_filename = quant.gguf_filename();
let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model_id)?;
let gguf_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(gguf_repo)?;
if !cache_dir.join("tokenizer.json").exists() {
let id = model_id.to_string();
let c = cache_dir.clone();
tokio::task::spawn_blocking(move || {
crate::backend::onnx::OnnxBackend::download_hf_file(&id, "tokenizer.json", &c)
.map_err(InferenceError::HubError)
})
.await
.map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
}
if !cache_dir.join("config.json").exists() {
let id = model_id.to_string();
let c = cache_dir.clone();
tokio::task::spawn_blocking(move || {
crate::backend::onnx::OnnxBackend::download_hf_file(&id, "config.json", &c)
.map_err(InferenceError::HubError)
})
.await
.map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
}
let gguf_path = gguf_dir.join(gguf_filename);
if !gguf_path.exists() {
let repo = gguf_repo.to_string();
let fname = gguf_filename.to_string();
let c = gguf_dir.clone();
tokio::task::spawn_blocking(move || {
crate::backend::onnx::OnnxBackend::download_hf_file(&repo, &fname, &c)
.map_err(InferenceError::HubError)
})
.await
.map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
}
let tokenizer = Tokenizer::from_file(cache_dir.join("tokenizer.json"))
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let bert_config: BertConfig = {
let f = std::fs::File::open(cache_dir.join("config.json"))?;
serde_json::from_reader(f)
.map_err(|e| InferenceError::ModelLoadError(format!("config.json: {e}")))?
};
let model = tokio::task::spawn_blocking(move || {
let mut file = std::fs::File::open(&gguf_path)
.map_err(|e| InferenceError::ModelLoadError(format!("open GGUF: {e}")))?;
let gguf_content = gguf_file::Content::read(&mut file)
.map_err(|e| InferenceError::ModelLoadError(format!("parse GGUF: {e}")))?;
let vb = VarBuilder::from_gguf_buffer(
&std::fs::read(&gguf_path)
.map_err(|e| InferenceError::ModelLoadError(format!("read GGUF: {e}")))?,
&device,
)
.map_err(|e| InferenceError::ModelLoadError(format!("VarBuilder: {e}")))?;
let _ = gguf_content; BertModel::load(vb, &bert_config)
.map_err(|e| InferenceError::ModelLoadError(format!("BertModel::load: {e}")))
})
.await
.map_err(|e| InferenceError::ModelLoadError(format!("GGUF load panicked: {e}")))??;
let dimension = bert_config.hidden_size;
let processor = Arc::new(BatchProcessor::new(
tokenizer,
config.model,
config.max_batch_size,
));
info!(
"GgufBackend ready: dimension={}, quant={:?}",
dimension, quant
);
Ok(Self {
model: Arc::new(model),
processor,
device,
dimension,
quant,
})
}
pub fn quant_level(&self) -> QuantLevel {
self.quant
}
}
#[async_trait]
impl EmbeddingBackend for GgufBackend {
async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(vec![]);
}
let model = Arc::clone(&self.model);
let processor = Arc::clone(&self.processor);
let texts_owned = texts.to_vec();
let device = self.device.clone();
let dim = self.dimension;
tokio::task::spawn_blocking(move || {
let prepared = processor.tokenize_batch(&texts_owned)?;
let batch_size = prepared.batch_size;
let seq_len = prepared.seq_len;
let to_tensor = |data: Vec<i64>| -> candle_core::Result<Tensor> {
Tensor::from_vec(data, (batch_size, seq_len), &device)
};
let input_ids = to_tensor(prepared.input_ids)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let attention_mask = to_tensor(prepared.attention_mask.clone())
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let token_type_ids = to_tensor(prepared.token_type_ids)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let hidden = model
.forward(&input_ids, &token_type_ids, Some(&attention_mask))
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_f32 = attention_mask
.to_dtype(DType::F32)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_sum = mask_f32
.sum(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let hidden_f32 = hidden
.to_dtype(DType::F32)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_expanded = mask_f32
.unsqueeze(2)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let masked = (hidden_f32 * mask_expanded)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let summed = masked
.sum(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mask_sum_exp = mask_sum
.unsqueeze(1)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let pooled = (summed / mask_sum_exp)
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let data = pooled
.to_vec2::<f32>()
.map_err(|e| InferenceError::InferenceError(e.to_string()))?;
let mut embeddings: Vec<Vec<f32>> = data
.into_iter()
.map(|r| r.into_iter().take(dim).collect())
.collect();
normalize_embeddings(&mut embeddings);
Ok(embeddings)
})
.await
.map_err(|e| InferenceError::InferenceError(format!("GgufBackend task panicked: {e}")))?
}
fn dimension(&self) -> usize {
self.dimension
}
fn backend_kind(&self) -> BackendKind {
BackendKind::Gguf
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Guards `DAKERA_QUANT_LEVEL`.
    ///
    /// The environment is process-wide and the test harness runs tests on
    /// several threads, so every test that reads or writes the variable must
    /// hold this lock for its whole body.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so that one failed test does
    /// not cascade into the others.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_quant_level_from_env_default() {
        let _env = env_guard();
        std::env::remove_var("DAKERA_QUANT_LEVEL");
        assert_eq!(QuantLevel::from_env(), QuantLevel::Q8_0);
    }

    #[test]
    fn test_quant_level_from_env_q4() {
        let _env = env_guard();
        std::env::set_var("DAKERA_QUANT_LEVEL", "q4_k_m");
        assert_eq!(QuantLevel::from_env(), QuantLevel::Q4KM);
        std::env::remove_var("DAKERA_QUANT_LEVEL");
    }

    #[test]
    fn test_quant_level_from_env_q5() {
        let _env = env_guard();
        std::env::set_var("DAKERA_QUANT_LEVEL", "q5_k_m");
        assert_eq!(QuantLevel::from_env(), QuantLevel::Q5KM);
        std::env::remove_var("DAKERA_QUANT_LEVEL");
    }

    #[test]
    fn test_quant_level_gguf_filenames_distinct() {
        assert_ne!(
            QuantLevel::Q8_0.gguf_filename(),
            QuantLevel::Q4KM.gguf_filename()
        );
        assert_ne!(
            QuantLevel::Q5KM.gguf_filename(),
            QuantLevel::Q4KM.gguf_filename()
        );
    }

    #[test]
    fn test_backend_kind_gguf() {
        assert_eq!(BackendKind::Gguf.to_string(), "gguf");
    }
}