use crate::backend::{BackendKind, EmbeddingBackend};
use crate::error::{InferenceError, Result};
use crate::models::ModelConfig;
use async_trait::async_trait;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument};
/// Embedding backend built on a static (Model2Vec) token-embedding matrix.
///
/// A text's embedding is the L2-normalised mean of the matrix rows of its token ids. Ids with
/// no corresponding row are skipped. The matrix and tokenizer are held behind `Arc`.
pub struct StaticBackend {
    /// Number of `f32` values per matrix row, i.e. the width of every embedding.
    dimension: usize,
    /// Number of rows in `vocab_matrix`; token ids at or above this value are ignored.
    vocab_size: usize,
    /// Row-major `vocab_size * dimension` values; row `i` is the embedding of token id `i`.
    vocab_matrix: Arc<Vec<f32>>,
    /// Tokenizer producing the token ids that index `vocab_matrix`.
    tokenizer: Arc<Tokenizer>,
}
impl StaticBackend {
#[instrument(skip_all)]
pub async fn new(config: &ModelConfig) -> Result<Self> {
let config = config.clone();
info!("Initialising StaticBackend (Model2Vec)");
let dim = Self::model2vec_dimension();
let model_id = config.model.model_id();
let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model_id)?;
if !cache_dir.join("tokenizer.json").exists() {
let model_id_owned = model_id.to_string();
let cache_dir_clone = cache_dir.clone();
tokio::task::spawn_blocking(move || {
crate::backend::onnx::OnnxBackend::download_hf_file(
&model_id_owned,
"tokenizer.json",
&cache_dir_clone,
)
.map_err(InferenceError::HubError)
})
.await
.map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
}
let tokenizer_path = cache_dir.join("tokenizer.json");
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let vocab_matrix = Self::load_vocab_matrix(&config, dim).await?;
let vocab_size = vocab_matrix.len() / dim;
info!(
"StaticBackend ready: vocab_size={}, dimension={}",
vocab_size, dim
);
Ok(Self {
vocab_matrix: Arc::new(vocab_matrix),
tokenizer: Arc::new(tokenizer),
dimension: dim,
vocab_size,
})
}
pub fn from_matrix(matrix: Vec<f32>, tokenizer: Tokenizer, dimension: usize) -> Result<Self> {
if !matrix.len().is_multiple_of(dimension) {
return Err(InferenceError::InvalidInput(format!(
"vocab_matrix length {} is not divisible by dimension {}",
matrix.len(),
dimension
)));
}
let vocab_size = matrix.len() / dimension;
Ok(Self {
vocab_matrix: Arc::new(matrix),
tokenizer: Arc::new(tokenizer),
dimension,
vocab_size,
})
}
pub fn model2vec_dimension() -> usize {
std::env::var("DAKERA_MRL_DIM")
.ok()
.and_then(|v| v.parse::<usize>().ok())
.filter(|&d| d > 0)
.unwrap_or(256)
}
#[instrument(skip(self, text), fields(text_len = text.len()))]
fn embed_single(&self, text: &str) -> Vec<f32> {
let encoding = match self.tokenizer.encode(text, false) {
Ok(enc) => enc,
Err(_) => return vec![0.0; self.dimension],
};
let ids = encoding.get_ids();
if ids.is_empty() {
return vec![0.0; self.dimension];
}
let mut result = vec![0.0f32; self.dimension];
let mut valid_tokens = 0usize;
for &id in ids {
let idx = id as usize;
if idx >= self.vocab_size {
continue;
}
let offset = idx * self.dimension;
let row = &self.vocab_matrix[offset..offset + self.dimension];
for (r, v) in result.iter_mut().zip(row.iter()) {
*r += v;
}
valid_tokens += 1;
}
if valid_tokens == 0 {
return vec![0.0; self.dimension];
}
let n = valid_tokens as f32;
for v in result.iter_mut() {
*v /= n;
}
let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
for v in result.iter_mut() {
*v /= norm;
}
result
}
async fn load_vocab_matrix(config: &ModelConfig, _dim: usize) -> Result<Vec<f32>> {
let model2vec_repo = config.model.model2vec_repo_id();
let cache_dir = crate::backend::onnx::OnnxBackend::model_cache_dir(model2vec_repo)?;
let matrix_path = cache_dir.join("vocab_matrix.bin");
if !matrix_path.exists() {
info!("Downloading Model2Vec vocab matrix from {}", model2vec_repo);
let repo = model2vec_repo.to_string();
let cache = cache_dir.clone();
tokio::task::spawn_blocking(move || {
crate::backend::onnx::OnnxBackend::download_hf_file(
&repo,
"vocab_matrix.bin",
&cache,
)
.map_err(InferenceError::HubError)
})
.await
.map_err(|e| InferenceError::HubError(format!("Download panicked: {e}")))??;
}
info!("Loading vocab matrix from {:?}", matrix_path);
let bytes = std::fs::read(&matrix_path)?;
if bytes.len() % 4 != 0 {
return Err(InferenceError::ModelLoadError(format!(
"vocab_matrix.bin size {} is not a multiple of 4 bytes",
bytes.len()
)));
}
let floats: Vec<f32> = bytes
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
debug!("Vocab matrix loaded: {} f32 values", floats.len());
Ok(floats)
}
}
#[async_trait]
impl EmbeddingBackend for StaticBackend {
    /// Embeds each text independently, preserving input order.
    ///
    /// An empty slice yields an empty vector. This implementation always returns `Ok`.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            embeddings.push(self.embed_single(text));
        }
        Ok(embeddings)
    }

    /// Returns the length of every vector produced by this backend.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Identifies this backend as [`BackendKind::Static`].
    fn backend_kind(&self) -> BackendKind {
        BackendKind::Static
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokenizers::models::wordlevel::WordLevel;
    use tokenizers::pre_tokenizers::whitespace::Whitespace;

    /// Builds a whitespace-split word-level tokenizer in which `words[i]` has id `i`.
    ///
    /// `"[UNK]"` is the unknown token, so `words` should contain it.
    fn make_test_tokenizer(words: &[&str]) -> Tokenizer {
        let mut vocab = std::collections::HashMap::new();
        for (i, w) in words.iter().enumerate() {
            vocab.insert(w.to_string(), i as u32);
        }
        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();
        let mut tok = Tokenizer::new(model);
        tok.with_pre_tokenizer(Some(Whitespace {}));
        tok
    }

    /// Builds a `vocab_size x dim` matrix whose row `i` is the unit vector at column `i % dim`.
    fn make_identity_matrix(vocab_size: usize, dim: usize) -> Vec<f32> {
        let mut m = vec![0.0f32; vocab_size * dim];
        for i in 0..vocab_size {
            m[i * dim + (i % dim)] = 1.0;
        }
        m
    }

    #[test]
    fn test_static_backend_from_matrix_dimension() {
        let words = ["[UNK]", "hello", "world", "test", "foo"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(5, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        assert_eq!(backend.dimension(), 4);
    }

    #[test]
    fn test_static_backend_from_matrix_vocab_size() {
        let words = ["[UNK]", "a", "b", "c"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(4, 8);
        let backend = StaticBackend::from_matrix(matrix, tok, 8).unwrap();
        assert_eq!(backend.vocab_size, 4);
    }

    #[test]
    fn test_static_backend_kind() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![0.0f32; 2 * 4];
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        assert_eq!(backend.backend_kind(), BackendKind::Static);
    }

    #[test]
    fn test_static_embed_empty_text_returns_zeros() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        // Non-zero rows, so a zero result can only come from the "no tokens" path.
        let matrix = vec![1.0f32; 2 * 4];
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let result = backend.embed_single("");
        assert_eq!(result.len(), 4);
        assert!(result.iter().all(|&v| v.abs() < 1e-6));
    }

    #[test]
    fn test_static_embed_single_token_normalized() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let mut matrix = vec![0.0f32; 3 * 4];
        // Row 1 ("hello") is the unit vector on column 0.
        matrix[4] = 1.0;
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let emb = backend.embed_single("hello");
        assert_eq!(emb.len(), 4);
        assert!((emb[0] - 1.0).abs() < 1e-5);
        assert!(emb[1].abs() < 1e-5);
    }

    #[test]
    fn test_static_embed_averages_multiple_tokens() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(3, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        // mean(e1, e2) = [0, 0.5, 0.5, 0], which normalises to [0, 1/sqrt(2), 1/sqrt(2), 0].
        let emb = backend.embed_single("hello world");
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!(emb[0].abs() < 1e-5);
        assert!((emb[1] - expected).abs() < 1e-5);
        assert!((emb[2] - expected).abs() < 1e-5);
        assert!(emb[3].abs() < 1e-5);
    }

    #[test]
    fn test_static_embed_unknown_word_uses_unk_row() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(2, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        // Out-of-vocabulary words map to "[UNK]" (id 0), whose row is e0.
        let emb = backend.embed_single("nonsense");
        assert!((emb[0] - 1.0).abs() < 1e-5);
        assert!(emb[1..].iter().all(|v| v.abs() < 1e-5));
    }

    #[test]
    fn test_static_embed_skips_ids_without_matrix_row() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        // Only two rows: "world" (id 2) has no embedding and must be ignored.
        let matrix = make_identity_matrix(2, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        assert!(backend.embed_single("world").iter().all(|v| v.abs() < 1e-6));
        let mixed = backend.embed_single("hello world");
        assert!((mixed[1] - 1.0).abs() < 1e-5);
        assert!(mixed[2].abs() < 1e-5);
    }

    #[test]
    fn test_static_embed_is_unit_norm() {
        let words = ["[UNK]", "a", "b"];
        let tok = make_test_tokenizer(&words);
        let matrix: Vec<f32> = (0..3 * 4).map(|i| i as f32 + 1.0).collect();
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let emb = backend.embed_single("a b a");
        let norm = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_static_embed_invalid_matrix_dimension_error() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![1.0f32; 5];
        let result = StaticBackend::from_matrix(matrix, tok, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_model2vec_dimension_default() {
        // NOTE(review): this mutates process-wide env state; it is the only test that touches
        // DAKERA_MRL_DIM, so parallel test execution does not race on it today.
        std::env::remove_var("DAKERA_MRL_DIM");
        assert_eq!(StaticBackend::model2vec_dimension(), 256);
    }

    #[tokio::test]
    async fn test_static_embed_batch_empty() {
        let words = ["[UNK]", "hello"];
        let tok = make_test_tokenizer(&words);
        let matrix = vec![0.0f32; 2 * 4];
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let result = backend.embed_batch(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_static_embed_batch_multiple() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let matrix = make_identity_matrix(3, 4);
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let results = backend.embed_batch(&texts).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 4);
        assert_eq!(results[1].len(), 4);
    }

    #[tokio::test]
    async fn test_static_embed_batch_preserves_order() {
        let words = ["[UNK]", "hello", "world"];
        let tok = make_test_tokenizer(&words);
        let mut matrix = vec![0.0f32; 3 * 4];
        // "hello" (row 1) points along column 0; "world" (row 2) points along column 1.
        matrix[4] = 1.0;
        matrix[9] = 1.0;
        let backend = StaticBackend::from_matrix(matrix, tok, 4).unwrap();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let results = backend.embed_batch(&texts).await.unwrap();
        assert!(results[0][0] > results[0][1]);
        assert!(results[1][1] > results[1][0]);
    }
}