use std::sync::Arc;
use async_trait::async_trait;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{Config as XLMRobertaConfig, XLMRobertaModel};
use solo_core::{Embedder, Embedding, EmbeddingDtype, Error, Result};
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer, TruncationParams};
use super::bge_m3::{BGE_M3_NAME, BGE_M3_VERSION, BgeM3Manifest};
/// CPU inference wrapper around a candle `XLMRobertaModel` loaded with BGE-M3 weights.
///
/// Constructed by [`BgeM3Inference::load`]; implements [`Embedder`], producing
/// L2-normalised first-token (CLS) embeddings encoded as little-endian `f32`.
pub struct BgeM3Inference {
    /// Transformer encoder built from the manifest's safetensors weights.
    model: XLMRobertaModel,
    /// Tokenizer, reconfigured in `load` for batch-longest padding and truncation.
    tokenizer: Arc<Tokenizer>,
    /// Device that input tensors are created on (`load` always uses the CPU).
    device: Device,
    /// Embedding width, taken from the config's `hidden_size`.
    dim: usize,
    /// Value returned by `Embedder::name`.
    name: String,
    /// Value returned by `Embedder::version`.
    version: String,
}
impl BgeM3Inference {
pub fn load(manifest: &BgeM3Manifest) -> Result<Self> {
let device = Device::Cpu;
let weights_path = &manifest.weights_path;
if !weights_path
.file_name()
.map(|n| n.to_string_lossy().ends_with(".safetensors"))
.unwrap_or(false)
{
return Err(Error::embedder(format!(
"BGE-M3 inference requires model.safetensors; got {weights_path:?}"
)));
}
let config_raw = std::fs::read_to_string(&manifest.config_path)
.map_err(|e| Error::embedder(format!("read config.json: {e}")))?;
let candle_cfg: XLMRobertaConfig = serde_json::from_str(&config_raw)
.map_err(|e| Error::embedder(format!("parse XLMRoberta Config from config.json: {e}")))?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path.as_path()], DType::F32, &device)
.map_err(|e| Error::embedder(format!("VarBuilder::from_mmaped_safetensors: {e}")))?
};
let model = XLMRobertaModel::new(&candle_cfg, vb)
.map_err(|e| Error::embedder(format!("XLMRobertaModel::new: {e}")))?;
let mut tokenizer = Tokenizer::from_file(&manifest.tokenizer_path)
.map_err(|e| Error::embedder(format!("Tokenizer::from_file: {e}")))?;
tokenizer
.with_padding(Some(PaddingParams {
strategy: PaddingStrategy::BatchLongest,
pad_id: candle_cfg.pad_token_id,
pad_token: "<pad>".to_string(),
..Default::default()
}))
.with_truncation(Some(TruncationParams {
max_length: candle_cfg.max_position_embeddings.saturating_sub(2).max(1),
..Default::default()
}))
.map_err(|e| Error::embedder(format!("tokenizer truncation config: {e}")))?;
Ok(Self {
model,
tokenizer: Arc::new(tokenizer),
device,
dim: candle_cfg.hidden_size,
name: manifest.config.model_type.clone(),
version: BGE_M3_VERSION.to_string(),
})
}
fn forward_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let encodings = self
.tokenizer
.encode_batch(texts.to_vec(), true)
.map_err(|e| Error::embedder(format!("tokenizer.encode_batch: {e}")))?;
let batch = encodings.len();
let seq_len = encodings.first().map(|e| e.len()).unwrap_or(0);
if seq_len == 0 {
return Err(Error::embedder("tokenizer produced empty sequences"));
}
let mut input_ids: Vec<u32> = Vec::with_capacity(batch * seq_len);
let mut attention_mask: Vec<u32> = Vec::with_capacity(batch * seq_len);
for enc in &encodings {
input_ids.extend_from_slice(enc.get_ids());
attention_mask.extend_from_slice(enc.get_attention_mask());
}
let token_type_ids = vec![0u32; batch * seq_len];
let input_ids = Tensor::from_vec(input_ids, (batch, seq_len), &self.device)
.map_err(|e| Error::embedder(format!("input_ids tensor: {e}")))?;
let attention_mask = Tensor::from_vec(attention_mask, (batch, seq_len), &self.device)
.map_err(|e| Error::embedder(format!("attention_mask tensor: {e}")))?;
let token_type_ids = Tensor::from_vec(token_type_ids, (batch, seq_len), &self.device)
.map_err(|e| Error::embedder(format!("token_type_ids tensor: {e}")))?;
let hidden = self
.model
.forward(&input_ids, &attention_mask, &token_type_ids, None, None, None)
.map_err(|e| Error::embedder(format!("XLMRobertaModel::forward: {e}")))?;
let cls = hidden
.narrow(1, 0, 1)
.and_then(|t| t.squeeze(1))
.map_err(|e| Error::embedder(format!("CLS pool: {e}")))?;
let norm = cls
.sqr()
.and_then(|t| t.sum_keepdim(1))
.and_then(|t| t.sqrt())
.map_err(|e| Error::embedder(format!("L2 norm: {e}")))?;
let normalized = cls
.broadcast_div(&norm)
.map_err(|e| Error::embedder(format!("normalize divide: {e}")))?;
let normalized = normalized
.to_dtype(DType::F32)
.map_err(|e| Error::embedder(format!("to_dtype f32: {e}")))?;
let rows = normalized
.to_vec2::<f32>()
.map_err(|e| Error::embedder(format!("to_vec2: {e}")))?;
Ok(rows)
}
}
#[async_trait]
impl Embedder for BgeM3Inference {
    /// Identifier of this embedder, as set by `BgeM3Inference::load`.
    fn name(&self) -> &str {
        &self.name
    }

    /// Version string of this embedder, as set by `BgeM3Inference::load`.
    fn version(&self) -> &str {
        &self.version
    }

    /// Number of `f32` components in every embedding (the model's `hidden_size`).
    fn dim(&self) -> usize {
        self.dim
    }

    /// Embeddings are always emitted as `f32`.
    fn dtype(&self) -> EmbeddingDtype {
        EmbeddingDtype::F32
    }

    /// Embeds `texts`, returning one little-endian `f32` [`Embedding`] per input, in order.
    ///
    /// Inference is synchronous and CPU-bound. How it runs depends on the executor:
    /// - On a multi-threaded Tokio runtime it runs inside `tokio::task::block_in_place`, so
    ///   the worker's other tasks can be moved elsewhere.
    /// - On any other executor (a current-thread runtime, or no Tokio runtime) it runs inline.
    ///   `block_in_place` would panic on a current-thread runtime.
    ///
    /// # Errors
    ///
    /// Propagates inference errors, and returns an embedder error if the model produces
    /// vectors whose length differs from [`Embedder::dim`].
    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let dim = self.dim;

        let on_multi_thread_runtime = tokio::runtime::Handle::try_current()
            .map(|handle| {
                matches!(
                    handle.runtime_flavor(),
                    tokio::runtime::RuntimeFlavor::MultiThread
                )
            })
            .unwrap_or(false);
        // The closure may borrow `texts` directly: `block_in_place` does not require 'static.
        let rows = if on_multi_thread_runtime {
            tokio::task::block_in_place(|| self.forward_batch(texts))
        } else {
            self.forward_batch(texts)
        }?;

        rows.into_iter()
            .map(|row| {
                if row.len() != dim {
                    return Err(Error::embedder(format!(
                        "BGE-M3 produced {} dims, expected {}",
                        row.len(),
                        dim
                    )));
                }
                let mut data = Vec::with_capacity(dim * std::mem::size_of::<f32>());
                data.extend(row.iter().flat_map(|v| v.to_le_bytes()));
                Ok(Embedding {
                    dtype: EmbeddingDtype::F32,
                    dim,
                    data,
                })
            })
            .collect()
    }
}
/// Canonical `(name, version)` pair identifying the BGE-M3 embedder.
///
/// Needs no loaded model: both parts come from the BGE-M3 module constants.
pub fn canonical_identity() -> (&'static str, &'static str) {
    let name = BGE_M3_NAME;
    let version = BGE_M3_VERSION;
    (name, version)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A parseable config paired with bytes that are not safetensors must make `load` fail.
    /// The error must name the stage that failed.
    #[test]
    fn load_with_garbage_safetensors_fails_clearly() {
        use crate::embedder::bge_m3::{BgeM3Config, BgeM3Loader, BgeM3Manifest};

        let scratch = tempfile::TempDir::new().unwrap();
        let root = scratch.path();
        let config_file = root.join("config.json");
        let tokenizer_file = root.join("tokenizer.json");
        let weights_file = root.join("model.safetensors");

        let config_json = r#"{
"model_type": "xlm-roberta",
"hidden_size": 1024,
"layer_norm_eps": 1e-5,
"attention_probs_dropout_prob": 0.1,
"hidden_dropout_prob": 0.1,
"num_attention_heads": 16,
"position_embedding_type": "absolute",
"intermediate_size": 4096,
"hidden_act": "gelu",
"num_hidden_layers": 24,
"vocab_size": 250002,
"max_position_embeddings": 8194,
"type_vocab_size": 1,
"pad_token_id": 1
}"#;
        std::fs::write(&config_file, config_json).unwrap();
        std::fs::write(&tokenizer_file, r#"{"version":"1.0","model":{}}"#).unwrap();
        std::fs::write(&weights_file, b"NOT A SAFETENSORS FILE").unwrap();

        let manifest = BgeM3Manifest {
            model_dir: root.to_path_buf(),
            config_path: config_file,
            tokenizer_path: tokenizer_file,
            weights_path: weights_file,
            config: BgeM3Config {
                model_type: "xlm-roberta".into(),
                hidden_size: 1024,
                vocab_size: 250002,
                max_position_embeddings: 8194,
            },
        };
        let _ = BgeM3Loader::open(root).expect("loader sanity");

        let err = BgeM3Inference::load(&manifest).err().expect("expected Err");
        let msg = err.to_string();
        let names_failing_stage = ["VarBuilder", "safetensors", "Tokenizer", "from_mmaped"]
            .iter()
            .any(|needle| msg.contains(*needle));
        assert!(names_failing_stage, "unexpected error: {msg}");
    }

    /// A non-safetensors weights file must be rejected before anything else is parsed.
    #[test]
    fn rejects_pytorch_bin_weights() {
        use crate::embedder::bge_m3::{BgeM3Config, BgeM3Manifest};

        let scratch = tempfile::TempDir::new().unwrap();
        let root = scratch.path();
        let config_file = root.join("config.json");
        let tokenizer_file = root.join("tokenizer.json");
        let bin_file = root.join("pytorch_model.bin");
        std::fs::write(&config_file, "{}").unwrap();
        std::fs::write(&tokenizer_file, "{}").unwrap();
        std::fs::write(&bin_file, b"fake").unwrap();

        let manifest = BgeM3Manifest {
            model_dir: root.to_path_buf(),
            config_path: config_file,
            tokenizer_path: tokenizer_file,
            weights_path: bin_file,
            config: BgeM3Config {
                model_type: "xlm-roberta".into(),
                hidden_size: 1024,
                vocab_size: 0,
                max_position_embeddings: 0,
            },
        };

        let err = BgeM3Inference::load(&manifest).err().expect("expected Err");
        let rendered = err.to_string();
        assert!(
            rendered.contains("requires model.safetensors"),
            "got: {err}"
        );
    }
}