#![cfg(feature = "embeddings")]
use std::{
collections::HashMap,
fs,
path::{Path, PathBuf},
process,
};
use mlxrs::{
Array, Error,
embeddings::{
EmbeddingModel, EmbeddingModelConfiguration, EmbeddingModelConstructor, EmbeddingModelOutput,
EmbeddingModelTypeRegistry, EmbeddingWeights, LoadedEmbeddingModel, PoolingStrategy, load,
remap_model_type,
},
error::{FileOp, RankMismatchPayload},
io,
};
/// Creates a fresh, empty scratch directory for one test and returns its path.
///
/// The directory lives under the system temp dir and is named
/// `mlxrs_emb_load_<pid>_<name>`. Any leftover directory with that name is removed first.
fn temp_dir(name: &str) -> PathBuf {
    let leaf = format!("mlxrs_emb_load_{}_{}", process::id(), name);
    let scratch = std::env::temp_dir().join(leaf);
    // Best-effort cleanup: a failure here (typically "does not exist") is deliberately ignored.
    let _ = fs::remove_dir_all(&scratch);
    fs::create_dir_all(&scratch).unwrap();
    scratch
}
/// Builds the text of a minimal model `config.json` for the given `model_type`.
///
/// `hidden_size` is fixed at 4 and `vocab_size` at 5. `model_type` is inserted verbatim
/// (no JSON escaping).
fn config_json(model_type: &str) -> String {
    let mut json = String::from(r#"{"model_type": ""#);
    json.push_str(model_type);
    json.push_str(r#"", "hidden_size": 4, "vocab_size": 5}"#);
    json
}
/// Stateless stand-in embedding model used by the loader tests.
struct MockEmbedding;

impl EmbeddingModel for MockEmbedding {
    /// Returns an all-zero hidden state of shape `(batch, seq, 4)` for rank-2 `input_ids`.
    ///
    /// The attention mask is ignored. Any other input rank yields `Error::RankMismatch`
    /// carrying the observed rank and shape.
    fn forward(
        &self,
        input_ids: &Array,
        _attention_mask: &Array,
    ) -> Result<EmbeddingModelOutput, Error> {
        // Width of the produced hidden state; same value as `hidden_size` in `config_json`.
        const HIDDEN: usize = 4;

        let dims = input_ids.shape();
        let (batch, seq) = match dims.as_slice() {
            &[batch, seq] => (batch, seq),
            _ => {
                let rank = dims.len() as u32;
                return Err(Error::RankMismatch(RankMismatchPayload::new(
                    "MockEmbedding::forward expects rank-2 (batch, seq) ids",
                    rank,
                    dims,
                )));
            }
        };

        let zeros = vec![0.0_f32; batch * seq * HIDDEN];
        let hidden_state = Array::from_slice::<f32>(&zeros, &(batch, seq, HIDDEN)).unwrap();
        Ok(EmbeddingModelOutput::from_hidden_state(hidden_state))
    }
}
fn mock_constructor() -> EmbeddingModelConstructor {
Box::new(
|loaded: &LoadedEmbeddingModel| -> Result<Box<dyn EmbeddingModel>, Error> {
assert!(!loaded.weights_ref().is_empty());
Ok(Box::new(MockEmbedding))
},
)
}
/// Writes a minimal word-level `tokenizer.json` into `dir`.
///
/// The vocabulary is `a`, `b`, `c` (ids 0, 1, 2) with `a` as the unknown token and
/// whitespace pre-tokenization.
fn write_tokenizer(dir: &Path) {
    use tokenizers::{
        Tokenizer as HfTokenizer, models::wordlevel::WordLevel,
        pre_tokenizers::whitespace::Whitespace,
    };

    // The collection type is left to inference so it matches whatever `vocab` expects.
    let vocab = ["a", "b", "c"]
        .iter()
        .zip(0u32..)
        .map(|(word, id)| (word.to_string(), id))
        .collect();
    let model = WordLevel::builder()
        .vocab(vocab)
        .unk_token(String::from("a"))
        .build()
        .unwrap();

    let mut tokenizer = HfTokenizer::new(model);
    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
    tokenizer.save(dir.join("tokenizer.json"), false).unwrap();
}
/// Populates `dir` with a complete minimal model: `config.json` for `model_type`, a
/// single-tensor `model.safetensors` (key `mock.weight`, shape 2x2), and `tokenizer.json`.
fn write_model_dir(dir: &Path, model_type: &str) {
    fs::write(dir.join("config.json"), config_json(model_type)).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2usize, 2)).unwrap();
    let weights: EmbeddingWeights = HashMap::from([("mock.weight".to_owned(), tensor)]);
    io::save_safetensors(&dir.join("model.safetensors"), &weights).unwrap();

    write_tokenizer(dir);
}
/// Happy path through the public surface: a well-formed `bert` directory loads into a
/// usable context.
///
/// Verifies the reported model type, that no pooling override is present (the fixture
/// writes no `1_Pooling/config.json`), that the constructed mock model runs a forward pass
/// of the expected shape, and that the loaded tokenizer encodes three words to three ids.
#[test]
fn load_produces_context_via_public_surface() {
    let model_dir = temp_dir("ctx");
    write_model_dir(&model_dir, "bert");
    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);

    let ctx = load(&configuration, &registry).expect("load should succeed");
    assert_eq!(ctx.model_type, "bert");
    assert!(ctx.pooling.is_none(), "no 1_Pooling/config.json written");

    // One batch row of three token ids with an all-ones attention mask.
    let ids = Array::from_slice::<i32>(&[0, 1, 2], &(1usize, 3)).unwrap();
    let mask = Array::from_slice::<f32>(&[1.0, 1.0, 1.0], &(1usize, 3)).unwrap();
    let output = ctx.model.forward(&ids, &mask).unwrap();
    assert_eq!(output.last_hidden_state().shape(), vec![1, 3, 4]);

    let token_ids = ctx.tokenizer.encode("a b c", false).unwrap();
    assert_eq!(token_ids.len(), 3);
}
/// A `1_Pooling/config.json` inside the model directory is parsed and surfaced on the
/// loaded context (CLS strategy, dimension 4 for this fixture).
#[test]
fn load_parses_pooling_config_when_present() {
    let model_dir = temp_dir("pooling");
    write_model_dir(&model_dir, "bert");

    let pooling_dir = model_dir.join("1_Pooling");
    fs::create_dir_all(&pooling_dir).unwrap();
    let pooling_json = r#"{"word_embedding_dimension": 4, "pooling_mode_cls_token": true}"#;
    fs::write(pooling_dir.join("config.json"), pooling_json).unwrap();

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);
    let ctx = load(&configuration, &registry).unwrap();

    let pooling = ctx.pooling.expect("pooling config parsed");
    assert_eq!(pooling.strategy(), PoolingStrategy::Cls);
    assert_eq!(pooling.dimension(), Some(4));
}
/// A `model_type` with no registered constructor fails the load with a message that says
/// the type is unsupported and names it.
#[test]
fn unknown_model_type_errors() {
    let model_dir = temp_dir("unknown");
    write_model_dir(&model_dir, "no_such_arch");
    // Only `bert` is registered, so `no_such_arch` cannot be resolved.
    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);

    let Err(err) = load(&configuration, &registry) else {
        panic!("unknown model_type must error");
    };

    let msg = err.to_string();
    assert!(msg.contains("unsupported model type"), "got: {msg}");
    assert!(msg.contains("no_such_arch"), "should name the type: {msg}");
}
/// Loading from an empty directory (no `config.json`) fails with an error that mentions
/// `config.json`.
#[test]
fn missing_config_errors() {
    // The scratch directory is created empty on purpose.
    let model_dir = temp_dir("noconfig");
    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);

    let Err(err) = load(&configuration, &registry) else {
        panic!("missing config must error");
    };

    assert!(err.to_string().contains("config.json"), "got: {err}");
}
/// Every way of spelling an empty model location is rejected as `Error::EmptyInput`.
///
/// The message must explain the empty-path rejection and must not mention `config.json`
/// or missing weights, showing the check fires before config/shard resolution.
#[test]
fn empty_model_directory_errors_before_filesystem_root_scan() {
    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let empty_configs = [
        EmbeddingModelConfiguration::from_directory(""),
        EmbeddingModelConfiguration::from_id(""),
        EmbeddingModelConfiguration::from_directory(PathBuf::new()),
    ];

    for config in &empty_configs {
        let Err(err) = load(config, &registry) else {
            panic!("an empty model directory path must be a recoverable error, not a load");
        };
        assert!(
            matches!(err, Error::EmptyInput(_)),
            "expected EmptyInput error; got {err:?}"
        );

        let msg = err.to_string();
        assert!(
            msg.contains("model directory path must not be empty"),
            "the error should explain the empty-path rejection; got: {msg}"
        );
        let mentions_later_stage = msg.contains("config.json") || msg.contains("no model weights");
        assert!(
            !mentions_later_stage,
            "the empty path must be rejected before config/shard resolution; got: {msg}"
        );
    }
}
/// An empty `tokenizer_source` on an otherwise valid model is rejected as
/// `Error::EmptyInput` with a message about the empty tokenizer directory path.
#[test]
fn empty_tokenizer_source_errors() {
    let model_dir = temp_dir("empty-tok-src");
    write_model_dir(&model_dir, "bert");
    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let config = EmbeddingModelConfiguration::from_directory(&model_dir).with_tokenizer_source("");

    let Err(err) = load(&config, &registry) else {
        panic!("an empty tokenizer_source path must be a recoverable error");
    };

    assert!(
        matches!(err, Error::EmptyInput(_)),
        "expected EmptyInput error; got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("tokenizer directory path must not be empty"),
        "the error should explain the empty tokenizer-path rejection; got: {err}"
    );
}
/// A `tokenizer_source` directory without `tokenizer.json` fails with `Error::FileIo`.
///
/// The payload must carry the tokenizer directory, the `tokenizer_load` op and the
/// `embeddings load: tokenizer` context. The wrapped `io::Error` must still expose the
/// original typed tokenizer error through `get_ref()` and downcasting.
#[test]
fn invalid_tokenizer_dir_surfaces_path_context_and_preserves_typed_source() {
    let model_dir = temp_dir("invalid-tok-dir-model");
    write_model_dir(&model_dir, "bert");
    // A separate, empty directory: it exists but contains no `tokenizer.json`.
    let tok_dir = temp_dir("invalid-tok-dir-tok");
    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let config =
        EmbeddingModelConfiguration::from_directory(&model_dir).with_tokenizer_source(&tok_dir);

    let Err(err) = load(&config, &registry) else {
        panic!("expected FileIo error when tokenizer.json is missing in the tokenizer_source dir");
    };
    let Error::FileIo(payload) = &err else {
        panic!("expected Error::FileIo wrapping the typed Tokenizer source, got {err:?}");
    };

    assert_eq!(
        payload.path(),
        tok_dir.as_path(),
        "FileIo payload must carry the selected tokenizer directory, got {:?}",
        payload.path()
    );
    assert_eq!(payload.op(), FileOp::Other("tokenizer_load"));
    assert_eq!(payload.context(), "embeddings load: tokenizer");

    // The io::Error is only a carrier; the original typed error must be reachable inside it.
    let inner_dyn = payload
        .inner()
        .get_ref()
        .expect("io::Error must carry the boxed mlxrs::Error inner via get_ref()");
    let inner_msg = inner_dyn.to_string();
    assert!(
        inner_msg.contains("load tokenizer.json"),
        "the inner typed Error::Tokenizer Display must be reachable via \
         io::Error::get_ref(); got inner message: {inner_msg:?}"
    );

    let downcast = inner_dyn
        .downcast_ref::<Error>()
        .expect("the io::Error get_ref must downcast to the original mlxrs::Error");
    assert!(
        downcast.is_tokenizer(),
        "the preserved typed source must be Error::Tokenizer; got {downcast:?}"
    );
}
/// A hyphenated `model_type` (`xlm-roberta`) is normalized to `xlm_roberta`, both by the
/// public `remap_model_type` and when resolving the constructor during `load`.
#[test]
fn separator_normalization_via_public_remap() {
    assert_eq!(remap_model_type("xlm-roberta"), "xlm_roberta");

    let model_dir = temp_dir("sep");
    write_model_dir(&model_dir, "xlm-roberta");
    // Registered under the underscore spelling only.
    let registry = EmbeddingModelTypeRegistry::new().with("xlm_roberta", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);

    let ctx = load(&configuration, &registry).unwrap();
    assert_eq!(ctx.model_type, "xlm_roberta");
}
/// Attempts to write a safetensors file at `path` holding one two-element `f32` tensor
/// under `key`. Returns whether the write succeeded instead of panicking.
#[cfg(unix)]
fn try_write_one_tensor(path: &Path, key: &str) -> bool {
    let tensor = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
    let weights: HashMap<String, Array> = HashMap::from([(key.to_owned(), tensor)]);
    io::save_safetensors(path, &weights).is_ok()
}
/// A root shard whose file name is not valid UTF-8 makes `load` return `Error::FileIo`
/// (rather than panic), with a message explaining the rejection and naming the directory.
///
/// The test is skipped when the fixture file cannot be written.
#[cfg(unix)]
#[test]
fn load_non_utf8_leaf_model_shard_is_recoverable_error() {
    use std::{ffi::OsString, os::unix::ffi::OsStringExt};

    let model_dir = temp_dir("nonutf8-leaf");
    fs::write(model_dir.join("config.json"), config_json("bert")).unwrap();

    // `model<0xFF>.safetensors`: shaped like a model shard name, but not valid UTF-8.
    let bad_leaf = OsString::from_vec(b"model\xFF.safetensors".to_vec());
    if !try_write_one_tensor(&model_dir.join(&bad_leaf), "mock.weight") {
        eprintln!("skipping non-UTF-8 shard assertion: the fixture file could not be written");
        return;
    }

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);
    let Err(err) = load(&configuration, &registry) else {
        panic!("a non-UTF-8 model*.safetensors leaf must be a recoverable error, not a panic");
    };

    assert!(
        matches!(err, Error::FileIo(_)),
        "expected FileIo error; got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("non-UTF-8 file name") && msg.contains("shard pattern"),
        "the error should explain the non-UTF-8 shard-name rejection; got: {msg}"
    );
    let dir_text = model_dir.display().to_string();
    assert!(
        msg.contains(&dir_text),
        "the error should name the offending shard path; got: {msg}"
    );
}
/// With a non-UTF-8 `model*.safetensors` shard and a readable `weights.safetensors` side by
/// side, `load` must fail with the non-UTF-8 shard error instead of loading the latter.
///
/// The test is skipped when the non-UTF-8 fixture file cannot be written.
#[cfg(unix)]
#[test]
fn load_non_utf8_leaf_shard_wins_over_stale_weight_fallback() {
    use std::{ffi::OsString, os::unix::ffi::OsStringExt};

    let model_dir = temp_dir("nonutf8-stale");
    fs::write(model_dir.join("config.json"), config_json("bert")).unwrap();

    // `model<0xFF>.safetensors`: shaped like a model shard name, but not valid UTF-8.
    let bad_leaf = OsString::from_vec(b"model\xFF.safetensors".to_vec());
    if !try_write_one_tensor(&model_dir.join(&bad_leaf), "primary.weight") {
        eprintln!("skipping non-UTF-8 shard assertion: the fixture file could not be written");
        return;
    }
    assert!(
        try_write_one_tensor(&model_dir.join("weights.safetensors"), "stale.weight"),
        "the legacy fallback shard must write on any filesystem"
    );

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);
    let Err(err) = load(&configuration, &registry) else {
        panic!(
            "a non-UTF-8 model*.safetensors primary shard must fail the load, NOT silently fall \
             back to the stale weight*.safetensors snapshot"
        );
    };

    assert!(
        matches!(err, Error::FileIo(_)),
        "expected FileIo error; got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("non-UTF-8 file name") && msg.contains("shard pattern"),
        "the load must fail with the non-UTF-8 shard error, not silently load the stale \
         weight*.safetensors fallback; got: {msg}"
    );
}
/// A non-UTF-8 shard file inside a component subfolder (`text_model/`) also makes `load`
/// return `Error::FileIo`, with a message that names the nested directory.
///
/// The test is skipped when the fixture file cannot be written.
#[cfg(unix)]
#[test]
fn load_non_utf8_leaf_shard_nested_under_subfolder_is_recoverable_error() {
    use std::{ffi::OsString, os::unix::ffi::OsStringExt};

    let model_dir = temp_dir("nonutf8-nested-leaf");
    fs::write(model_dir.join("config.json"), config_json("bert")).unwrap();
    let nested = model_dir.join("text_model");
    fs::create_dir_all(&nested).unwrap();

    // `model<0xFF>.safetensors`: shaped like a model shard name, but not valid UTF-8.
    let bad_leaf = OsString::from_vec(b"model\xFF.safetensors".to_vec());
    if !try_write_one_tensor(&nested.join(&bad_leaf), "encoder.weight") {
        eprintln!("skipping non-UTF-8 shard assertion: the fixture file could not be written");
        return;
    }

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);
    let Err(err) = load(&configuration, &registry) else {
        panic!("a non-UTF-8 leaf shard nested under a subfolder must be a recoverable error");
    };

    assert!(
        matches!(err, Error::FileIo(_)),
        "expected FileIo error; got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("non-UTF-8 file name") && msg.contains("shard pattern"),
        "the error should explain the non-UTF-8 shard-name rejection; got: {msg}"
    );
    let nested_text = nested.display().to_string();
    assert!(
        msg.contains(&nested_text),
        "the error should name the offending nested shard path; got: {msg}"
    );
}
/// A `1_Pooling/config.json` that is a dangling symlink must fail the load with a typed
/// `FileIo` error (`FileOp::Open`, `NotFound`, path = the symlink itself), not be treated
/// as an absent pooling config.
#[cfg(unix)]
#[test]
fn load_broken_symlink_pooling_config_is_recoverable_error() {
    let model_dir = temp_dir("pooling-broken-symlink");
    write_model_dir(&model_dir, "bert");
    let pooling_dir = model_dir.join("1_Pooling");
    fs::create_dir_all(&pooling_dir).unwrap();

    // The link exists, but its target is never created.
    let expected_pooling_config_path = pooling_dir.join("config.json");
    std::os::unix::fs::symlink(
        model_dir.join("nonexistent_pooling_target.json"),
        &expected_pooling_config_path,
    )
    .unwrap();

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);
    let Err(err) = load(&configuration, &registry) else {
        panic!(
            "a present-but-broken (dangling-symlink) 1_Pooling/config.json must fail the load, \
             NOT be silently treated as absent and fall back to default pooling"
        );
    };
    let Error::FileIo(payload) = &err else {
        panic!("expected a typed FileIo error; got {err:?}");
    };

    assert_eq!(
        payload.op(),
        FileOp::Open,
        "expected the OPEN phase for a dangling-symlink target; got {:?}",
        payload.op()
    );
    assert_eq!(
        payload.inner().kind(),
        std::io::ErrorKind::NotFound,
        "expected the inner io::Error::NotFound from following a dangling symlink to a missing \
         target; got {:?}",
        payload.inner().kind()
    );
    assert_eq!(
        payload.path(),
        expected_pooling_config_path.as_path(),
        "expected the FileIo path to be the planted broken-symlink config path"
    );
}
/// A `1_Pooling` entry that is itself a dangling symlink must fail the load with a typed
/// `FileIo` error (`FileOp::Stat`, `NotFound`) whose path is the `1_Pooling` parent, not
/// the `config.json` child underneath it.
#[cfg(unix)]
#[test]
fn load_broken_parent_symlink_pooling_dir_is_recoverable_error() {
    let model_dir = temp_dir("pooling-broken-parent-symlink");
    write_model_dir(&model_dir, "bert");

    // The link exists, but the directory it points at is never created.
    let expected_parent_path = model_dir.join("1_Pooling");
    std::os::unix::fs::symlink(model_dir.join("nonexistent_pooling_dir"), &expected_parent_path)
        .unwrap();

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);
    let Err(err) = load(&configuration, &registry) else {
        panic!(
            "a present-but-broken (dangling-symlink) 1_Pooling parent directory must fail the \
             load, NOT be silently treated as absent and fall back to default pooling"
        );
    };
    let Error::FileIo(payload) = &err else {
        panic!("expected a typed FileIo error; got {err:?}");
    };

    assert_eq!(
        payload.op(),
        FileOp::Stat,
        "broken-parent-symlink probe MUST surface as FileOp::Stat (the metadata \
         follow-the-link call that failed); got {:?}",
        payload.op()
    );
    assert_eq!(
        payload.inner().kind(),
        std::io::ErrorKind::NotFound,
        "broken-parent-symlink MUST carry the REAL io::Error::NotFound from \
         following the dangling parent symlink; got {:?}",
        payload.inner().kind()
    );
    assert_eq!(
        payload.path(),
        expected_parent_path.as_path(),
        "the FileIo path MUST be the broken 1_Pooling parent (NOT the unreachable \
         1_Pooling/config.json child) so callers can distinguish parent-vs-child \
         failure modes; got {}",
        payload.path().display()
    );
}
/// A `1_Pooling` symlink that resolves to a real but empty directory is not an error: the
/// model loads and no pooling override is reported.
#[cfg(unix)]
#[test]
fn load_resolvable_parent_symlink_pooling_dir_with_no_config_is_absent() {
    let model_dir = temp_dir("pooling-resolvable-parent-symlink");
    write_model_dir(&model_dir, "bert");

    // The link target exists and stays empty (no `config.json` inside).
    let target_dir = model_dir.join("real_pooling_target");
    fs::create_dir_all(&target_dir).unwrap();
    std::os::unix::fs::symlink(&target_dir, model_dir.join("1_Pooling")).unwrap();

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);
    let ctx = load(&configuration, &registry).expect(
        "a resolvable 1_Pooling symlink pointing at a real (empty) directory must still load \
         (no config.json inside ⇒ no pooling override)",
    );

    assert!(
        ctx.pooling.is_none(),
        "no 1_Pooling/config.json inside the resolved target ⇒ no pooling override"
    );
}
/// With no `1_Pooling` directory at all, the model still loads and reports no pooling
/// override.
#[test]
fn load_absent_pooling_config_still_loads_with_default_pooling() {
    let model_dir = temp_dir("pooling-genuinely-absent");
    write_model_dir(&model_dir, "bert");
    // Guard the fixture itself, so the test cannot pass for the wrong reason.
    assert!(
        !model_dir.join("1_Pooling").exists(),
        "this fixture must not have a 1_Pooling directory"
    );

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let configuration = EmbeddingModelConfiguration::from_directory(&model_dir);
    let ctx = load(&configuration, &registry)
        .expect("a genuinely-absent 1_Pooling/config.json must still load (default pooling)");

    assert!(
        ctx.pooling.is_none(),
        "no 1_Pooling/config.json ⇒ no pooling override"
    );
}
/// A `1_Pooling/config.json` that cannot be inspected because its directory has mode `000`
/// must fail the load with a typed `FileIo` error (`FileOp::Stat`, `PermissionDenied`),
/// not be treated as an absent pooling config.
///
/// The assertions are skipped when the environment does not enforce the permission, which
/// is detected by probing the file with `symlink_metadata` while the mode is `000`.
#[cfg(unix)]
#[test]
fn load_permission_denied_pooling_config_is_recoverable_error() {
    use std::os::unix::fs::PermissionsExt;

    /// Puts mode `755` back on the wrapped directory when dropped, so the fixture does not
    /// stay locked if `load` panics.
    struct RestoreMode<'a>(&'a Path);
    impl Drop for RestoreMode<'_> {
        fn drop(&mut self) {
            let _ = fs::set_permissions(self.0, fs::Permissions::from_mode(0o755));
        }
    }

    let model_dir = temp_dir("pooling-perm-denied");
    write_model_dir(&model_dir, "bert");
    let pooling_dir = model_dir.join("1_Pooling");
    fs::create_dir_all(&pooling_dir).unwrap();
    fs::write(
        pooling_dir.join("config.json"),
        r#"{"word_embedding_dimension": 4, "pooling_mode_cls_token": true}"#,
    )
    .unwrap();

    fs::set_permissions(&pooling_dir, fs::Permissions::from_mode(0o000)).unwrap();
    let restore_mode = RestoreMode(&pooling_dir);
    // If this lstat succeeds, the permission is not being enforced here.
    let enforced = fs::symlink_metadata(pooling_dir.join("config.json")).is_err();

    let registry = EmbeddingModelTypeRegistry::new().with("bert", mock_constructor());
    let result = load(
        &EmbeddingModelConfiguration::from_directory(&model_dir),
        &registry,
    );
    // Unlock the directory before any assertion can panic.
    drop(restore_mode);

    if !enforced {
        eprintln!(
            "skipping permission-denied pooling-config assertion: this environment \
             does not enforce directory search permission"
        );
        return;
    }

    let Err(err) = result else {
        panic!(
            "a permission-denied 1_Pooling/config.json must fail the load, NOT be silently \
             treated as absent and fall back to default pooling"
        );
    };
    let Error::FileIo(payload) = &err else {
        panic!("expected a typed FileIo error; got {err:?}");
    };

    assert_eq!(
        payload.op(),
        FileOp::Stat,
        "permission-denied symlink_metadata MUST surface as FileOp::Stat \
         (the lstat that failed), not Other; got {:?}",
        payload.op()
    );
    assert_eq!(
        payload.inner().kind(),
        std::io::ErrorKind::PermissionDenied,
        "permission-denied path MUST carry the REAL io::Error from \
         symlink_metadata (PermissionDenied), not a synthetic NotFound from \
         Path::exists(); got {:?}",
        payload.inner().kind()
    );
    assert!(
        payload.path().ends_with("1_Pooling/config.json"),
        "expected the FileIo path to end with 1_Pooling/config.json; got {}",
        payload.path().display()
    );
}
fn capturing_constructor(
slot: std::sync::Arc<std::sync::Mutex<Vec<String>>>,
) -> EmbeddingModelConstructor {
Box::new(
move |loaded: &LoadedEmbeddingModel| -> Result<Box<dyn EmbeddingModel>, Error> {
let mut keys: Vec<String> = loaded.weights_ref().keys().cloned().collect();
keys.sort();
*slot.lock().unwrap() = keys;
Ok(Box::new(MockEmbedding))
},
)
}
/// Serializes tests that change the process-wide current directory.
static CWD_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Guard that switches the current directory and switches it back when dropped.
///
/// It holds `CWD_TEST_LOCK` for its whole lifetime. Fields are dropped only after
/// `Drop::drop` has run, so the lock is released after the directory is restored.
struct RestoreCwd {
    /// Directory that was current before the switch.
    saved: PathBuf,
    /// Held only for its lifetime; never read.
    _lock: std::sync::MutexGuard<'static, ()>,
}

impl RestoreCwd {
    /// Takes the cwd lock, remembers the current directory, then changes to `to`.
    ///
    /// A poisoned lock (an earlier holder panicked) is still taken. Panics if the current
    /// directory cannot be read or changed.
    fn change_to(to: &Path) -> Self {
        let guard = match CWD_TEST_LOCK.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let previous = std::env::current_dir().expect("read current dir");
        std::env::set_current_dir(to).expect("set current dir");
        Self {
            saved: previous,
            _lock: guard,
        }
    }
}

impl Drop for RestoreCwd {
    /// Restores the saved directory; a failure is ignored because `drop` cannot report it.
    fn drop(&mut self) {
        let _ = std::env::set_current_dir(&self.saved);
    }
}
/// Creates `<scratch(tag)>/<leaf>` holding a `bert` config, a tokenizer, and one weight
/// file named `weight_file` with a single two-element tensor under `weight_key`.
///
/// Returns the scratch parent directory and the leaf name, so callers can address the
/// model relative to the parent.
fn write_relative_model_dir(
    tag: &str,
    leaf: &str,
    weight_file: &str,
    weight_key: &str,
) -> (PathBuf, String) {
    let parent = temp_dir(tag);
    let model_dir = parent.join(leaf);
    fs::create_dir_all(&model_dir).unwrap();

    fs::write(model_dir.join("config.json"), config_json("bert")).unwrap();
    write_tokenizer(&model_dir);

    let tensor = Array::from_slice::<f32>(&[1.0, 2.0], &(2usize,)).unwrap();
    let weights: EmbeddingWeights = HashMap::from([(weight_key.to_owned(), tensor)]);
    io::save_safetensors(&model_dir.join(weight_file), &weights).unwrap();

    (parent, String::from(leaf))
}
/// Loading with the model directory spelled `.` (the current directory is the model
/// directory) succeeds and passes the root shard's keys to the constructor unchanged.
#[test]
fn load_from_dot_directory_keeps_root_shard_keys_verbatim() {
    let (parent, leaf) = write_relative_model_dir("dot-dir", "ckpt", "model.safetensors", "root.w");
    let seen_keys = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let registry = EmbeddingModelTypeRegistry::new().with(
        "bert",
        capturing_constructor(std::sync::Arc::clone(&seen_keys)),
    );

    // Stand inside the model directory for the duration of the load.
    let _cwd = RestoreCwd::change_to(&parent.join(&leaf));
    load(&EmbeddingModelConfiguration::from_directory("."), &registry)
        .expect("from_directory(\".\") with a root model.safetensors must load, not error");

    assert_eq!(
        *seen_keys.lock().unwrap(),
        vec!["root.w".to_string()],
        "a root shard discovered via `.` must keep its keys verbatim (no prefix)"
    );
}
/// Loading with the model directory spelled `./<leaf>` passes the root shard's keys to the
/// constructor unchanged; in particular they do not gain a `<leaf>.` prefix.
#[test]
fn load_from_dot_slash_subdir_does_not_rewrite_root_shard_keys() {
    let (parent, leaf) =
        write_relative_model_dir("dotslash-model", "ckpt", "model.safetensors", "root.w");
    let seen_keys = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let registry = EmbeddingModelTypeRegistry::new().with(
        "bert",
        capturing_constructor(std::sync::Arc::clone(&seen_keys)),
    );

    // Stand in the parent so that `./<leaf>` resolves to the model directory.
    let _cwd = RestoreCwd::change_to(&parent);
    let spelling = format!("./{leaf}");
    load(
        &EmbeddingModelConfiguration::from_directory(&spelling),
        &registry,
    )
    .unwrap_or_else(|e| panic!("from_directory({spelling:?}) must load: {e}"));

    let got = seen_keys.lock().unwrap().clone();
    assert_eq!(
        got,
        vec!["root.w".to_string()],
        "a root shard reached via `./{leaf}` must keep keys verbatim, NOT be rewritten `{leaf}.<key>`"
    );
    // Built once rather than once per key.
    let leaf_prefix = format!("{leaf}.");
    assert!(
        !got.iter().any(|key| key.starts_with(&leaf_prefix)),
        "root keys must not gain a `{leaf}.` prefix; got {got:?}"
    );
}
/// A model whose only weight file is the legacy-named `weights.safetensors` loads when
/// addressed as `./<leaf>`, and its keys reach the constructor unchanged.
#[test]
fn load_from_dot_slash_subdir_loads_legacy_weight_glob_shard() {
    let (parent, leaf) =
        write_relative_model_dir("dotslash-weight", "ckpt", "weights.safetensors", "legacy.w");
    let seen_keys = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
    let registry = EmbeddingModelTypeRegistry::new().with(
        "bert",
        capturing_constructor(std::sync::Arc::clone(&seen_keys)),
    );

    // Stand in the parent so that `./<leaf>` resolves to the model directory.
    let _cwd = RestoreCwd::change_to(&parent);
    let spelling = format!("./{leaf}");
    load(
        &EmbeddingModelConfiguration::from_directory(&spelling),
        &registry,
    )
    .unwrap_or_else(|e| panic!("legacy weight*.safetensors via {spelling:?} must load: {e}"));

    assert_eq!(
        *seen_keys.lock().unwrap(),
        vec!["legacy.w".to_string()],
        "a legacy root weight*.safetensors reached via `./{leaf}` must keep keys verbatim"
    );
}
/// For both relative spellings (`<leaf>` and `./<leaf>`), a root shard keeps its keys
/// unchanged while a shard inside the `vision_model/` subfolder has its keys prefixed
/// with `vision_model.`.
#[test]
fn load_relative_directory_still_prefixes_genuine_nested_component_shards() {
    let (parent, leaf) =
        write_relative_model_dir("rel-nested", "ckpt", "model.safetensors", "embed.w");

    // Add a second shard one level down, under a component subfolder.
    let nested = parent.join(&leaf).join("vision_model");
    fs::create_dir_all(&nested).unwrap();
    let nested_tensor = Array::from_slice::<f32>(&[3.0, 4.0], &(2usize,)).unwrap();
    let nested_weights: EmbeddingWeights = HashMap::from([("enc.w".to_owned(), nested_tensor)]);
    io::save_safetensors(&nested.join("model.safetensors"), &nested_weights).unwrap();

    for spelling in [leaf.clone(), format!("./{leaf}")] {
        let seen_keys = std::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
        let registry = EmbeddingModelTypeRegistry::new().with(
            "bert",
            capturing_constructor(std::sync::Arc::clone(&seen_keys)),
        );

        let cwd_guard = RestoreCwd::change_to(&parent);
        load(
            &EmbeddingModelConfiguration::from_directory(&spelling),
            &registry,
        )
        .unwrap_or_else(|e| {
            panic!("model with a nested component shard via {spelling:?} must load: {e}")
        });
        // Restore the directory and release the cwd lock before asserting.
        drop(cwd_guard);

        let got = seen_keys.lock().unwrap().clone();
        assert_eq!(
            got,
            vec!["embed.w".to_string(), "vision_model.enc.w".to_string()],
            "via {spelling:?}: the root shard's keys must stay verbatim and the nested \
             component shard's keys must be prefixed `vision_model.`; got {got:?}"
        );
    }
}