use std::fs;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use solo_core::{Error, Result};
/// Model identifier for BGE-M3.
/// NOTE(review): presumably the Hugging Face repository id — confirm against callers.
pub const BGE_M3_NAME: &str = "BAAI/bge-m3";
/// Version tag for this model integration.
/// NOTE(review): what "v1" versions (cache layout, index schema, ...) is not established here — confirm.
pub const BGE_M3_VERSION: &str = "v1";
/// Expected embedding dimension for BGE-M3.
///
/// `BgeM3Loader::open` does not compare the model's `hidden_size` against this
/// constant; `BgeM3Loader::dim` reports whatever `config.json` says.
pub const BGE_M3_DIM: usize = 1024;
/// Model configuration file; required inside the model directory.
const FILE_CONFIG: &str = "config.json";
/// Tokenizer definition file; required inside the model directory.
const FILE_TOKENIZER: &str = "tokenizer.json";
/// Preferred weights file; checked before the PyTorch fallback.
const FILE_WEIGHTS_SAFETENSORS: &str = "model.safetensors";
/// Fallback weights file; used only when the safetensors file is absent.
const FILE_WEIGHTS_PYTORCH: &str = "pytorch_model.bin";
/// The subset of a model's `config.json` that the loader reads.
///
/// Deserialization ignores keys not listed here (there is no
/// `deny_unknown_fields`). `model_type` and `hidden_size` are required and
/// parsing fails without them; the other fields fall back to `0` when the key
/// is absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BgeM3Config {
    /// Architecture name as written in `config.json` (the tests use `"xlm-roberta"`).
    pub model_type: String,
    /// Hidden width; `BgeM3Loader::dim` returns this value and
    /// `BgeM3Loader::open` rejects `0`.
    pub hidden_size: usize,
    /// Vocabulary size; `0` when the key is missing.
    #[serde(default)]
    pub vocab_size: usize,
    /// Maximum number of position embeddings; `0` when the key is missing.
    #[serde(default)]
    pub max_position_embeddings: usize,
}
impl BgeM3Config {
pub fn read(path: &Path) -> Result<Self> {
let raw = fs::read_to_string(path)
.map_err(|e| Error::embedder(format!("read {path:?}: {e}")))?;
serde_json::from_str(&raw).map_err(Error::Serde)
}
}
/// Resolved file locations of a BGE-M3 model directory, plus its parsed config.
///
/// When produced by [`BgeM3Loader::open`], every path below was confirmed to be
/// an existing file at the time of the call.
#[derive(Debug, Clone)]
pub struct BgeM3Manifest {
    /// The directory that was passed to `BgeM3Loader::open`.
    pub model_dir: PathBuf,
    /// `<model_dir>/config.json`.
    pub config_path: PathBuf,
    /// `<model_dir>/tokenizer.json`.
    pub tokenizer_path: PathBuf,
    /// `<model_dir>/model.safetensors` if it exists, else `<model_dir>/pytorch_model.bin`.
    pub weights_path: PathBuf,
    /// Parsed contents of `config_path`.
    pub config: BgeM3Config,
}
/// Validated handle to an on-disk BGE-M3 model directory.
///
/// Create it with [`BgeM3Loader::open`]. Turn it into a boxed
/// `solo_core::Embedder` with [`BgeM3Loader::into_embedder`].
#[derive(Debug, Clone)]
pub struct BgeM3Loader {
    /// Paths and config resolved by `open`; private, so `open` is the only constructor.
    manifest: BgeM3Manifest,
}
impl BgeM3Loader {
pub fn open(model_dir: impl Into<PathBuf>) -> Result<Self> {
let model_dir = model_dir.into();
if !model_dir.is_dir() {
return Err(Error::embedder(format!(
"BGE-M3 model dir not found or not a directory: {model_dir:?}"
)));
}
let config_path = model_dir.join(FILE_CONFIG);
let tokenizer_path = model_dir.join(FILE_TOKENIZER);
for (name, p) in [
(FILE_CONFIG, &config_path),
(FILE_TOKENIZER, &tokenizer_path),
] {
if !p.is_file() {
return Err(Error::embedder(format!(
"BGE-M3 model dir is missing required file `{name}`: \
expected at {p:?}"
)));
}
}
let safetensors = model_dir.join(FILE_WEIGHTS_SAFETENSORS);
let pytorch = model_dir.join(FILE_WEIGHTS_PYTORCH);
let weights_path = if safetensors.is_file() {
safetensors
} else if pytorch.is_file() {
pytorch
} else {
return Err(Error::embedder(format!(
"BGE-M3 model dir has neither `{FILE_WEIGHTS_SAFETENSORS}` nor \
`{FILE_WEIGHTS_PYTORCH}`: expected one in {model_dir:?}"
)));
};
let config = BgeM3Config::read(&config_path)?;
if config.hidden_size == 0 {
return Err(Error::embedder(format!(
"BGE-M3 config.json reports hidden_size=0 — corrupt or wrong file? at {config_path:?}"
)));
}
Ok(Self {
manifest: BgeM3Manifest {
model_dir,
config_path,
tokenizer_path,
weights_path,
config,
},
})
}
pub fn manifest(&self) -> &BgeM3Manifest {
&self.manifest
}
pub fn dim(&self) -> usize {
self.manifest.config.hidden_size
}
pub fn into_embedder(self) -> Result<Box<dyn solo_core::Embedder>> {
let inference = crate::embedder::bge_m3_inference::BgeM3Inference::load(&self.manifest)?;
Ok(Box::new(inference))
}
}
#[cfg(test)]
mod tests {
    // Filesystem-level tests for `BgeM3Loader::open` (and one for
    // `into_embedder`) that build throwaway model directories with fake files.
    use super::*;
    use std::fs;

    /// Writes `body` to `path`, panicking on any I/O error (test-only helper).
    fn write(path: &Path, body: &[u8]) {
        fs::write(path, body).unwrap();
    }

    /// Minimal `config.json` containing the keys `BgeM3Config` reads.
    fn fake_config_json(hidden_size: usize) -> String {
        format!(
            r#"{{
    "model_type": "xlm-roberta",
    "hidden_size": {hidden_size},
    "vocab_size": 250002,
    "max_position_embeddings": 8194
}}"#
        )
    }

    /// A fuller XLM-RoBERTa style `config.json`.
    /// NOTE(review): presumably needed so `into_embedder` gets past config
    /// parsing and fails on the fake weights instead — confirm.
    fn full_xlm_config_json(hidden_size: usize) -> String {
        format!(
            r#"{{
    "model_type": "xlm-roberta",
    "hidden_size": {hidden_size},
    "layer_norm_eps": 1e-5,
    "attention_probs_dropout_prob": 0.1,
    "hidden_dropout_prob": 0.1,
    "num_attention_heads": 16,
    "position_embedding_type": "absolute",
    "intermediate_size": 4096,
    "hidden_act": "gelu",
    "num_hidden_layers": 24,
    "vocab_size": 250002,
    "max_position_embeddings": 8194,
    "type_vocab_size": 1,
    "pad_token_id": 1
}}"#
        )
    }

    /// Placeholder `tokenizer.json`; `open` only checks that the file exists.
    fn fake_tokenizer_json() -> &'static str {
        r#"{"version": "1.0", "model": {}}"#
    }

    #[test]
    fn open_validates_complete_safetensors_dir() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("config.json"), fake_config_json(1024).as_bytes());
        write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
        write(&dir.join("model.safetensors"), b"<fake weights>");
        let loader = BgeM3Loader::open(dir).unwrap();
        assert_eq!(loader.dim(), 1024);
        assert_eq!(loader.manifest().config.model_type, "xlm-roberta");
        assert!(
            loader
                .manifest()
                .weights_path
                .ends_with("model.safetensors"),
            "expected safetensors path, got {:?}",
            loader.manifest().weights_path
        );
    }

    #[test]
    fn open_falls_back_to_pytorch_bin() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("config.json"), fake_config_json(1024).as_bytes());
        write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
        write(&dir.join("pytorch_model.bin"), b"<fake>");
        let loader = BgeM3Loader::open(dir).unwrap();
        assert!(
            loader
                .manifest()
                .weights_path
                .ends_with("pytorch_model.bin")
        );
    }

    /// When both weight formats exist, safetensors must win.
    #[test]
    fn open_prefers_safetensors_when_both_weights_present() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("config.json"), fake_config_json(1024).as_bytes());
        write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
        write(&dir.join("model.safetensors"), b"<fake>");
        write(&dir.join("pytorch_model.bin"), b"<fake>");
        let loader = BgeM3Loader::open(dir).unwrap();
        assert!(
            loader
                .manifest()
                .weights_path
                .ends_with("model.safetensors"),
            "expected safetensors to be preferred, got {:?}",
            loader.manifest().weights_path
        );
    }

    /// Fields marked `#[serde(default)]` may be omitted and then read as 0.
    #[test]
    fn open_defaults_optional_config_fields_to_zero() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(
            &dir.join("config.json"),
            br#"{"model_type": "xlm-roberta", "hidden_size": 768}"#,
        );
        write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
        write(&dir.join("model.safetensors"), b"<fake>");
        let loader = BgeM3Loader::open(dir).unwrap();
        assert_eq!(loader.dim(), 768);
        let config = &loader.manifest().config;
        assert_eq!(config.vocab_size, 0);
        assert_eq!(config.max_position_embeddings, 0);
    }

    /// Non-JSON content and a missing required key must both be rejected.
    #[test]
    fn open_rejects_unparseable_config() {
        let bodies: [&[u8]; 2] = [b"not json at all", br#"{"model_type": "xlm-roberta"}"#];
        for body in bodies {
            let tmp = tempfile::TempDir::new().unwrap();
            let dir = tmp.path();
            write(&dir.join("config.json"), body);
            write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
            write(&dir.join("model.safetensors"), b"<fake>");
            assert!(
                BgeM3Loader::open(dir).is_err(),
                "expected Err for config body {:?}",
                String::from_utf8_lossy(body)
            );
        }
    }

    #[test]
    fn open_rejects_missing_config() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
        write(&dir.join("model.safetensors"), b"<fake>");
        let err = BgeM3Loader::open(dir).unwrap_err();
        assert!(err.to_string().contains("config.json"), "got: {err}");
    }

    #[test]
    fn open_rejects_missing_tokenizer() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("config.json"), fake_config_json(1024).as_bytes());
        write(&dir.join("model.safetensors"), b"<fake>");
        let err = BgeM3Loader::open(dir).unwrap_err();
        assert!(err.to_string().contains("tokenizer.json"), "got: {err}");
    }

    /// With both required files missing, the config is reported (it is checked first).
    #[test]
    fn open_reports_config_before_tokenizer_when_both_missing() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("model.safetensors"), b"<fake>");
        let msg = BgeM3Loader::open(dir).unwrap_err().to_string();
        assert!(msg.contains("config.json"), "got: {msg}");
        assert!(!msg.contains("tokenizer.json"), "got: {msg}");
    }

    #[test]
    fn open_rejects_missing_weights() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("config.json"), fake_config_json(1024).as_bytes());
        write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
        let err = BgeM3Loader::open(dir).unwrap_err();
        assert!(
            err.to_string().contains("model.safetensors"),
            "got: {err}"
        );
    }

    #[test]
    fn open_rejects_zero_hidden_size() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("config.json"), fake_config_json(0).as_bytes());
        write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
        write(&dir.join("model.safetensors"), b"<fake>");
        let err = BgeM3Loader::open(dir).unwrap_err();
        assert!(err.to_string().contains("hidden_size=0"), "got: {err}");
    }

    #[test]
    fn open_rejects_nonexistent_dir() {
        let err = BgeM3Loader::open("/definitely/not/here/solo/bge-m3").unwrap_err();
        assert!(err.to_string().contains("not found"), "got: {err}");
    }

    #[test]
    fn open_rejects_dir_that_is_a_file() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let err = BgeM3Loader::open(tmp.path()).unwrap_err();
        assert!(err.to_string().contains("not a directory"), "got: {err}");
    }

    /// `open` passes with fake weights; the failure must come from `into_embedder`.
    #[test]
    fn into_embedder_with_garbage_weights_fails_clearly() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dir = tmp.path();
        write(&dir.join("config.json"), full_xlm_config_json(1024).as_bytes());
        write(&dir.join("tokenizer.json"), fake_tokenizer_json().as_bytes());
        write(&dir.join("model.safetensors"), b"<fake>");
        let loader = BgeM3Loader::open(dir).unwrap();
        let err = match loader.into_embedder() {
            Ok(_) => panic!("expected Err but got Ok"),
            Err(e) => e,
        };
        let msg = err.to_string();
        assert!(
            msg.contains("safetensors")
                || msg.contains("VarBuilder")
                || msg.contains("Tokenizer")
                || msg.contains("from_mmaped"),
            "expected safetensors/VarBuilder/Tokenizer error, got: {msg}"
        );
    }
}