use std::collections::{HashMap, HashSet};
use std::path::Path;
use flodl::{DType, Device, Graph, Result, Tensor, TensorError};
use safetensors::{tensor::TensorView, Dtype, SafeTensors};
use crate::path::hf_key_from_flodl_key;
/// One tensor the model expects to find in a checkpoint: its canonical
/// (HF-style, dot-separated) key and its dimensions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpectedParam {
// Canonical HF-style tensor name (see `hf_key_from_flodl_key`).
pub key: String,
// Expected tensor dimensions.
pub shape: Vec<i64>,
}
/// A key present in both model and checkpoint whose shapes disagree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShapeMismatch {
// Canonical tensor name.
pub key: String,
// Shape the model declares.
pub expected: Vec<i64>,
// Shape found in the checkpoint.
pub found: Vec<i64>,
}
/// Result of comparing a model's expected tensors against a checkpoint.
/// All three buckets empty means the checkpoint matches exactly.
#[derive(Debug, Default, Clone)]
pub struct LoadValidation {
// Keys the model expects but the checkpoint lacks.
pub missing: Vec<String>,
// Keys the checkpoint has but the model does not use.
pub unused: Vec<String>,
// Keys present on both sides but with differing shapes.
pub shape_mismatches: Vec<ShapeMismatch>,
}
impl LoadValidation {
    /// True when the checkpoint matched the model exactly (no missing,
    /// unused, or shape-mismatched keys).
    pub fn is_ok(&self) -> bool {
        self.missing.is_empty() && self.unused.is_empty() && self.shape_mismatches.is_empty()
    }

    /// Convert into a `Result`, treating every discrepancy as an error.
    pub fn into_result(self) -> Result<()> {
        self.into_result_impl(false)
    }

    /// Like [`LoadValidation::into_result`], but tolerates checkpoint keys
    /// the model does not use (common when loading a backbone from a
    /// checkpoint that also carries task heads).
    pub fn into_result_allow_unused(self) -> Result<()> {
        self.into_result_impl(true)
    }

    /// Append one bucket of plain keys to the error message: a count
    /// header, at most 20 entries, and a "... and N more" tail.
    /// `what` is the header text after the count (e.g. "missing key(s) ...").
    fn push_key_section(msg: &mut String, keys: &[String], what: &str) {
        msg.push_str(&format!(" {} {what}:\n", keys.len()));
        for k in keys.iter().take(20) {
            msg.push_str(&format!(" - {k}\n"));
        }
        if keys.len() > 20 {
            msg.push_str(&format!(" ... and {} more\n", keys.len() - 20));
        }
    }

    /// Shared implementation: optionally drop the `unused` bucket, then
    /// either succeed or build one multi-section error message listing
    /// every remaining discrepancy (truncated to 20 entries per bucket).
    fn into_result_impl(mut self, allow_unused: bool) -> Result<()> {
        if allow_unused {
            // Unused checkpoint keys are tolerated: forget them before the check.
            self.unused.clear();
        }
        if self.is_ok() {
            return Ok(());
        }
        let mut msg = String::from("safetensors checkpoint does not match model:\n");
        if !self.missing.is_empty() {
            Self::push_key_section(
                &mut msg,
                &self.missing,
                "missing key(s) (model expects, checkpoint lacks)",
            );
        }
        if !self.unused.is_empty() {
            Self::push_key_section(
                &mut msg,
                &self.unused,
                "unused key(s) (checkpoint has, model lacks)",
            );
        }
        if !self.shape_mismatches.is_empty() {
            // Shape mismatches carry two shapes per entry, so they are
            // formatted inline rather than through push_key_section.
            msg.push_str(&format!(
                " {} shape mismatch(es):\n",
                self.shape_mismatches.len(),
            ));
            for m in self.shape_mismatches.iter().take(20) {
                msg.push_str(&format!(
                    " - {}: expected {:?}, found {:?}\n",
                    m.key, m.expected, m.found,
                ));
            }
            if self.shape_mismatches.len() > 20 {
                msg.push_str(&format!(
                    " ... and {} more\n",
                    self.shape_mismatches.len() - 20,
                ));
            }
        }
        Err(TensorError::new(&msg))
    }
}
/// Compare the model's expected tensors against the shapes found in a
/// checkpoint. Returns every discrepancy, with each bucket sorted by key
/// so error messages are deterministic.
pub fn validate_keys(
    expected: &[ExpectedParam],
    actual: &HashMap<String, Vec<i64>>,
) -> LoadValidation {
    let mut report = LoadValidation::default();

    // Pass 1: every expected tensor is either absent, shape-mismatched, or fine.
    for param in expected {
        if let Some(found) = actual.get(&param.key) {
            if *found != param.shape {
                report.shape_mismatches.push(ShapeMismatch {
                    key: param.key.clone(),
                    expected: param.shape.clone(),
                    found: found.clone(),
                });
            }
        } else {
            report.missing.push(param.key.clone());
        }
    }

    // Pass 2: checkpoint keys the model never asked for.
    let known: HashSet<&str> = expected.iter().map(|p| p.key.as_str()).collect();
    report.unused = actual
        .keys()
        .filter(|k| !known.contains(k.as_str()))
        .cloned()
        .collect();

    // Stable ordering keeps downstream messages reproducible.
    report.missing.sort();
    report.unused.sort();
    report.shape_mismatches.sort_by(|a, b| a.key.cmp(&b.key));
    report
}
/// Collect the canonical (HF-keyed) list of tensors the graph expects:
/// all named parameters followed by all named buffers.
pub fn expected_from_graph(graph: &Graph) -> Vec<ExpectedParam> {
    let params = graph
        .named_parameters()
        .into_iter()
        .map(|(key, param)| ExpectedParam {
            key: hf_key_from_flodl_key(&key),
            shape: param.variable.shape(),
        });
    let buffers = graph
        .named_buffers()
        .into_iter()
        .map(|(key, buf)| ExpectedParam {
            key: hf_key_from_flodl_key(&key),
            shape: buf.shape(),
        });
    params.chain(buffers).collect()
}
/// Load a safetensors blob into `graph` with keys used verbatim
/// (identity rename). Strict: any mismatch is an error.
pub fn load_safetensors_into_graph(graph: &Graph, bytes: &[u8]) -> Result<()> {
    load_safetensors_into_graph_with_rename(graph, bytes, |key| key.to_owned())
}
/// Load a safetensors blob into `graph`, mapping each checkpoint key
/// through `rename` before matching. Strict mode: unused checkpoint keys
/// are an error, so the (necessarily empty) unused-key list is discarded.
pub fn load_safetensors_into_graph_with_rename<F>(
    graph: &Graph,
    bytes: &[u8],
    rename: F,
) -> Result<()>
where
    F: Fn(&str) -> String,
{
    load_safetensors_core(graph, bytes, &rename, false).map(|_| ())
}
/// Like [`load_safetensors_into_graph_with_rename`], but tolerates
/// checkpoint keys the model does not use and returns them (sorted) so
/// the caller can log or inspect what was skipped.
pub fn load_safetensors_into_graph_with_rename_allow_unused<F>(
    graph: &Graph,
    bytes: &[u8],
    rename: F,
) -> Result<Vec<String>>
where
    F: Fn(&str) -> String,
{
    // Delegate to the shared core with unused keys tolerated.
    load_safetensors_core(graph, bytes, &rename, true)
}
/// Shared loading core: parse the blob, rename each checkpoint key to its
/// canonical form, validate the full key/shape inventory against the graph,
/// then copy every parameter and buffer tensor in. Returns the (possibly
/// empty) list of checkpoint keys the model did not use.
fn load_safetensors_core(
graph: &Graph,
bytes: &[u8],
rename: &dyn Fn(&str) -> String,
allow_unused: bool,
) -> Result<Vec<String>> {
let st = SafeTensors::deserialize(bytes)
.map_err(|e| TensorError::new(&format!("safetensors parse error: {e}")))?;
// canonical key -> original checkpoint key, so tensors can be fetched later.
let mut canonical_to_original: HashMap<String, String> = HashMap::new();
// canonical key -> shape, for validation against the graph's expectations.
let mut actual_shapes: HashMap<String, Vec<i64>> = HashMap::new();
for name in st.names() {
let canonical = rename(name);
// Two distinct checkpoint keys renaming to the same canonical key is
// ambiguous — refuse rather than silently pick one.
if let Some(prev) = canonical_to_original.insert(canonical.clone(), name.to_string()) {
return Err(TensorError::new(&format!(
"safetensors key rename collision: both {prev:?} and {name:?} \
map to canonical key {canonical:?}",
)));
}
let view = st.tensor(name)
.map_err(|e| TensorError::new(&format!("safetensors tensor lookup {name}: {e}")))?;
actual_shapes.insert(
canonical,
view.shape().iter().map(|&s| s as i64).collect(),
);
}
// Validate the complete inventory up front so nothing is partially loaded
// on failure; keep a copy of the unused keys to return on success.
let expected = expected_from_graph(graph);
let validation = validate_keys(&expected, &actual_shapes);
let unused = validation.unused.clone();
if allow_unused {
validation.into_result_allow_unused()?;
} else {
validation.into_result()?;
}
// Copy parameter tensors. Validation already guaranteed every key exists,
// so a miss here indicates an internal inconsistency.
for (flodl_key, param) in graph.named_parameters() {
let hf_key = hf_key_from_flodl_key(&flodl_key);
let original = canonical_to_original.get(&hf_key).ok_or_else(|| {
TensorError::new(&format!(
"canonical key {hf_key:?} missing from checkpoint after rename \
(validation should have caught this)",
))
})?;
let view = st.tensor(original)
.map_err(|e| TensorError::new(&format!("safetensors tensor {original}: {e}")))?;
// Materialise on the same device the existing parameter lives on.
let device = param.variable.data().device();
let src = tensor_view_to_tensor(&view, device)?;
param.variable.set_data(src);
}
// Copy buffer tensors (non-trainable state), same scheme as parameters.
for (flodl_key, buffer) in graph.named_buffers() {
let hf_key = hf_key_from_flodl_key(&flodl_key);
let original = canonical_to_original.get(&hf_key).ok_or_else(|| {
TensorError::new(&format!(
"canonical buffer key {hf_key:?} missing after rename",
))
})?;
let view = st.tensor(original)
.map_err(|e| TensorError::new(&format!("safetensors tensor {original}: {e}")))?;
let src = tensor_view_to_tensor(&view, buffer.device())?;
buffer.set(src);
}
Ok(unused)
}
/// Map legacy BERT checkpoint keys to their modern canonical names:
/// the tied MLM decoder-bias aliases (exact-match only) and the old
/// `LayerNorm.gamma`/`.beta` suffixes. All other keys pass through.
pub fn bert_legacy_key_rename(checkpoint_key: &str) -> String {
    match checkpoint_key {
        // Tied MLM bias aliases — exact keys only, never as a suffix.
        "cls.predictions.bias" => "cls.predictions.decoder.bias".to_string(),
        "lm_head.bias" => "lm_head.decoder.bias".to_string(),
        key => {
            // Legacy LayerNorm parameter names used gamma/beta.
            if let Some(stem) = key.strip_suffix("LayerNorm.gamma") {
                format!("{stem}LayerNorm.weight")
            } else if let Some(stem) = key.strip_suffix("LayerNorm.beta") {
                format!("{stem}LayerNorm.bias")
            } else {
                key.to_string()
            }
        }
    }
}
/// Rewrite only the legacy `LayerNorm.gamma`/`.beta` suffixes to
/// `.weight`/`.bias`; everything else (including MLM bias aliases)
/// passes through untouched.
pub fn bert_legacy_layernorm_rename(checkpoint_key: &str) -> String {
    let rewrites = [
        ("LayerNorm.gamma", "LayerNorm.weight"),
        ("LayerNorm.beta", "LayerNorm.bias"),
    ];
    for (old_suffix, new_suffix) in rewrites {
        if let Some(stem) = checkpoint_key.strip_suffix(old_suffix) {
            return format!("{stem}{new_suffix}");
        }
    }
    checkpoint_key.to_string()
}
/// Invert the tied MLM decoder-bias aliasing when saving, so checkpoints
/// written by us use the conventional HF key names.
pub fn hf_canonical_save_key(hf_key: &str) -> String {
    match hf_key {
        "cls.predictions.decoder.bias" => "cls.predictions.bias",
        "lm_head.decoder.bias" => "lm_head.bias",
        other => other,
    }
    .to_string()
}
/// Whether a checkpoint key names a pooler tensor. Handles both `/` and
/// `.` separators by normalising to dots, and both nested
/// (`pooler.dense.*`) and flat (`pooler.*`) layouts.
fn is_pooler_key(key: &str) -> bool {
    const POOLER_SUFFIXES: [&str; 4] = [
        "pooler.dense.weight",
        "pooler.dense.bias",
        "pooler.weight",
        "pooler.bias",
    ];
    let dotted = key.replace('/', ".");
    POOLER_SUFFIXES.iter().any(|suffix| dotted.ends_with(suffix))
}
/// Parse a safetensors blob and report whether any tensor name looks like
/// a pooler tensor (see `is_pooler_key`). Errors only on a malformed blob.
pub fn weights_have_pooler(weights: &[u8]) -> Result<bool> {
    let parsed = SafeTensors::deserialize(weights)
        .map_err(|e| TensorError::new(&format!("safetensors parse error: {e}")))?;
    let has_pooler = parsed.names().iter().any(|name| is_pooler_key(name));
    Ok(has_pooler)
}
pub fn keys_have_pooler(keys: &[String]) -> bool {
keys.iter().any(|k| is_pooler_key(k))
}
/// Read a safetensors file from disk and load it into `graph` (strict mode).
pub fn load_safetensors_file_into_graph(graph: &Graph, path: &Path) -> Result<()> {
    match std::fs::read(path) {
        Ok(bytes) => load_safetensors_into_graph(graph, &bytes),
        Err(e) => Err(TensorError::new(&format!(
            "safetensors read {}: {e}",
            path.display(),
        ))),
    }
}
/// Read a safetensors file from disk and load it into `graph`, renaming
/// each checkpoint key through `rename` first (strict mode).
pub fn load_safetensors_file_into_graph_with_rename<F>(
    graph: &Graph,
    path: &Path,
    rename: F,
) -> Result<()>
where
    F: Fn(&str) -> String,
{
    match std::fs::read(path) {
        Ok(bytes) => load_safetensors_into_graph_with_rename(graph, &bytes, rename),
        Err(e) => Err(TensorError::new(&format!(
            "safetensors read {}: {e}",
            path.display(),
        ))),
    }
}
/// Serialise every parameter and buffer of `graph` into a safetensors blob,
/// using canonical HF save keys. Errors on key collisions and on integer
/// dtypes (see `dtype_to_safetensors`).
pub fn save_safetensors_from_graph(graph: &Graph) -> Result<Vec<u8>> {
use std::collections::BTreeMap;
// BTreeMap keeps the serialised key order deterministic.
let mut entries: BTreeMap<String, (Dtype, Vec<usize>, Vec<u8>)> = BTreeMap::new();
for (flodl_key, param) in graph.named_parameters() {
let hf_key = hf_canonical_save_key(&hf_key_from_flodl_key(&flodl_key));
let shape: Vec<usize> = param.variable.shape().iter().map(|&d| d as usize).collect();
let dtype = param.variable.data().dtype();
let bytes = param.variable.data().to_blob()?;
// Two distinct tensors mapping to one save key would silently drop one —
// refuse and point at the graph tag as the thing to fix.
if entries.contains_key(&hf_key) {
return Err(TensorError::new(&format!(
"save_safetensors: HF key {hf_key:?} collision \
— multiple distinct flodl tensors map to the same name; \
fix the conflicting `tag(...)` in the graph",
)));
}
entries.insert(hf_key, (dtype_to_safetensors(dtype)?, shape, bytes));
}
// Buffers go through the same pipeline and share the same key namespace.
for (flodl_key, buffer) in graph.named_buffers() {
let hf_key = hf_canonical_save_key(&hf_key_from_flodl_key(&flodl_key));
let shape: Vec<usize> = buffer.shape().iter().map(|&d| d as usize).collect();
let dtype = buffer.get().dtype();
let bytes = buffer.get().to_blob()?;
if entries.contains_key(&hf_key) {
return Err(TensorError::new(&format!(
"save_safetensors: HF key {hf_key:?} collision \
— buffer collides with a parameter or another buffer",
)));
}
entries.insert(hf_key, (dtype_to_safetensors(dtype)?, shape, bytes));
}
// Build borrowing views over the owned byte buffers and serialise.
let views: HashMap<String, TensorView<'_>> = entries.iter()
.map(|(k, (dtype, shape, bytes))| {
let view = TensorView::new(*dtype, shape.clone(), bytes.as_slice())
.map_err(|e| TensorError::new(&format!(
"safetensors view build for {k:?}: {e}",
)))?;
Ok::<(String, TensorView<'_>), TensorError>((k.clone(), view))
})
.collect::<std::result::Result<_, _>>()?;
safetensors::serialize(&views, &None)
.map_err(|e| TensorError::new(&format!("safetensors serialize: {e}")))
}
/// Map a flodl dtype to its safetensors equivalent. Integer dtypes are
/// rejected: this module only serialises floating-point tensors.
fn dtype_to_safetensors(dtype: DType) -> Result<Dtype> {
    let mapped = match dtype {
        DType::Float32 => Dtype::F32,
        DType::Float64 => Dtype::F64,
        DType::Float16 => Dtype::F16,
        DType::BFloat16 => Dtype::BF16,
        DType::Int32 | DType::Int64 => {
            return Err(TensorError::new(&format!(
                "save_safetensors: integer dtype {dtype:?} not supported \
— only floating-point parameters / buffers can be serialised",
            )));
        }
    };
    Ok(mapped)
}
/// Serialise `graph` (see `save_safetensors_from_graph`) and write the
/// blob to `path`.
pub fn save_safetensors_file_from_graph(graph: &Graph, path: &Path) -> Result<()> {
    let blob = save_safetensors_from_graph(graph)?;
    match std::fs::write(path, &blob) {
        Ok(()) => Ok(()),
        Err(e) => Err(TensorError::new(&format!(
            "safetensors write {}: {e}",
            path.display(),
        ))),
    }
}
/// Build a flodl tensor on `target_device` from a safetensors view,
/// preserving the stored dtype. Only float dtypes are supported.
fn tensor_view_to_tensor(view: &TensorView, target_device: Device) -> Result<Tensor> {
    let dtype = match view.dtype() {
        Dtype::F32 => DType::Float32,
        Dtype::F64 => DType::Float64,
        Dtype::F16 => DType::Float16,
        Dtype::BF16 => DType::BFloat16,
        other => {
            return Err(TensorError::new(&format!(
                "unsupported safetensors dtype {other:?} — floats (F32/F64/BF16/F16) only",
            )));
        }
    };
    let dims: Vec<i64> = view.shape().iter().map(|&d| d as i64).collect();
    Tensor::from_blob(view.data(), &dims, dtype, target_device)
}
/// Decode a safetensors view into `f32` values from little-endian
/// F32/F64/BF16/F16 storage (F64 is narrowed, half formats widened).
///
/// Errors when the byte length is not a whole number of elements or the
/// dtype is not one of the four supported float formats.
pub fn tensor_view_to_f32_vec(view: &TensorView) -> Result<Vec<f32>> {
    // Shared length check: the buffer must hold a whole number of
    // `unit`-byte elements. `label` is the dtype name for the message.
    fn check_len(label: &str, bytes: &[u8], unit: usize) -> Result<()> {
        if bytes.len() % unit != 0 {
            return Err(TensorError::new(&format!(
                "{label} tensor byte length {} is not a multiple of {unit}",
                bytes.len(),
            )));
        }
        Ok(())
    }

    let bytes = view.data();
    match view.dtype() {
        Dtype::F32 => {
            check_len("F32", bytes, 4)?;
            Ok(bytes
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes(c.try_into().expect("chunk is 4 bytes")))
                .collect())
        }
        Dtype::F64 => {
            check_len("F64", bytes, 8)?;
            // Narrowing cast: values outside f32 range become +/-inf.
            Ok(bytes
                .chunks_exact(8)
                .map(|c| f64::from_le_bytes(c.try_into().expect("chunk is 8 bytes")) as f32)
                .collect())
        }
        Dtype::BF16 => {
            check_len("BF16", bytes, 2)?;
            // bf16 is the top 16 bits of an f32: widen by shifting left.
            Ok(bytes
                .chunks_exact(2)
                .map(|c| f32::from_bits((u16::from_le_bytes([c[0], c[1]]) as u32) << 16))
                .collect())
        }
        Dtype::F16 => {
            check_len("F16", bytes, 2)?;
            Ok(bytes
                .chunks_exact(2)
                .map(|c| f16_bits_to_f32(u16::from_le_bytes([c[0], c[1]])))
                .collect())
        }
        other => Err(TensorError::new(&format!(
            "unsupported safetensors dtype {other:?} — floats (F32/F64/BF16/F16) only",
        ))),
    }
}
/// Widen an IEEE 754 binary16 bit pattern to f32, handling signed zero,
/// subnormals, infinities, and NaN (payload preserved via the shift).
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = ((bits >> 15) as u32 & 0x1) << 31;
    let exponent = (bits >> 10) as u32 & 0x1f;
    let fraction = bits as u32 & 0x3ff;
    let word = match (exponent, fraction) {
        // Signed zero.
        (0, 0) => sign,
        // Subnormal half: renormalise into an f32 normal by shifting the
        // fraction up until its implicit leading bit (0x400) appears.
        (0, _) => {
            let mut frac = fraction;
            let mut exp2: i32 = -14;
            while frac & 0x400 == 0 {
                frac <<= 1;
                exp2 -= 1;
            }
            sign | ((((exp2 + 127) as u32) & 0xff) << 23) | ((frac & 0x3ff) << 13)
        }
        // Infinity (fraction == 0) or NaN (fraction != 0).
        (0x1f, _) => sign | (0xff << 23) | (fraction << 13),
        // Normal number: rebias the exponent (127 - 15 = 112).
        _ => sign | (((exponent + 112) & 0xff) << 23) | (fraction << 13),
    };
    f32::from_bits(word)
}
#[cfg(test)]
mod tests {
use super::*;
// Build an `actual` shape map from (key, shape) pairs for test setup.
fn actual_map(entries: &[(&str, &[i64])]) -> HashMap<String, Vec<i64>> {
    let mut map = HashMap::new();
    for &(key, shape) in entries {
        map.insert(key.to_string(), shape.to_vec());
    }
    map
}
// Pooler detection must accept TF-style slash-separated key paths.
#[test]
fn keys_have_pooler_detects_slash_separator() {
let keys = vec![
"encoder/layer/0/attention/query/weight".to_string(),
"pooler/dense/weight".to_string(),
"pooler/dense/bias".to_string(),
];
assert!(keys_have_pooler(&keys));
}
// ... and HF-style dot-separated keys.
#[test]
fn keys_have_pooler_detects_dot_separator() {
let keys = vec![
"encoder.layer.0.attention.query.weight".to_string(),
"pooler.dense.weight".to_string(),
];
assert!(keys_have_pooler(&keys));
}
// Encoder-only checkpoints must not be flagged as having a pooler.
#[test]
fn keys_have_pooler_returns_false_for_encoder_only() {
let keys = vec![
"encoder/layer/0/attention/query/weight".to_string(),
"encoder/layer/0/attention/query/bias".to_string(),
];
assert!(!keys_have_pooler(&keys));
}
// "pooler" appearing mid-component must not trigger a match.
#[test]
fn keys_have_pooler_does_not_match_substrings() {
let keys = vec!["encoder/some_pooler_thing/weight".to_string()];
assert!(!keys_have_pooler(&keys));
}
// Nested pooler layout (BERT: pooler.dense.*) is detected in a real blob.
#[test]
fn weights_have_pooler_detects_bert_style() {
let bytes = serialize_entries(&[
("bert.embeddings.word_embeddings.weight",
Dtype::F32, vec![2, 4], f32_le_bytes(&[0.0; 8])),
("bert.pooler.dense.weight",
Dtype::F32, vec![4, 4], f32_le_bytes(&[0.0; 16])),
("bert.pooler.dense.bias",
Dtype::F32, vec![4], f32_le_bytes(&[0.0; 4])),
]);
assert!(weights_have_pooler(&bytes).unwrap());
}
// Flat pooler layout (ALBERT: pooler.*) is detected too.
#[test]
fn weights_have_pooler_detects_albert_style_flat() {
let bytes = serialize_entries(&[
("albert.embeddings.word_embeddings.weight",
Dtype::F32, vec![2, 4], f32_le_bytes(&[0.0; 8])),
("albert.pooler.weight",
Dtype::F32, vec![4, 4], f32_le_bytes(&[0.0; 16])),
("albert.pooler.bias",
Dtype::F32, vec![4], f32_le_bytes(&[0.0; 4])),
]);
assert!(weights_have_pooler(&bytes).unwrap());
}
// Encoder-only blobs report false rather than erroring.
#[test]
fn weights_have_pooler_returns_false_for_encoder_only() {
let bytes = serialize_entries(&[
("roberta.embeddings.word_embeddings.weight",
Dtype::F32, vec![2, 4], f32_le_bytes(&[0.0; 8])),
("roberta.encoder.layer.0.attention.self.query.weight",
Dtype::F32, vec![4, 4], f32_le_bytes(&[0.0; 16])),
]);
assert!(!weights_have_pooler(&bytes).unwrap());
}
// A malformed blob surfaces the parse error instead of panicking.
#[test]
fn weights_have_pooler_errors_on_invalid_safetensors() {
let err = weights_have_pooler(b"not a safetensors blob")
.unwrap_err()
.to_string();
assert!(err.contains("safetensors parse error"), "got: {err}");
}
// Legacy gamma/beta LayerNorm suffixes are rewritten to weight/bias.
#[test]
fn bert_legacy_layernorm_rename_rewrites_gamma_and_beta() {
assert_eq!(
bert_legacy_layernorm_rename("bert.embeddings.LayerNorm.gamma"),
"bert.embeddings.LayerNorm.weight",
);
assert_eq!(
bert_legacy_layernorm_rename("bert.embeddings.LayerNorm.beta"),
"bert.embeddings.LayerNorm.bias",
);
}
// Modern keys are returned unchanged.
#[test]
fn bert_legacy_layernorm_rename_passthrough_for_modern_keys() {
let key = "bert.embeddings.LayerNorm.weight";
assert_eq!(bert_legacy_layernorm_rename(key), key);
}
// The LayerNorm-only rename must not also rewrite the MLM bias alias.
#[test]
fn bert_legacy_layernorm_rename_does_not_touch_mlm_alias() {
let k = "cls.predictions.bias";
assert_eq!(bert_legacy_layernorm_rename(k), k);
}
// Saving inverts the tied decoder-bias aliasing back to HF names.
#[test]
fn hf_canonical_save_key_inverts_mlm_decoder_bias_alias() {
assert_eq!(
hf_canonical_save_key("cls.predictions.decoder.bias"),
"cls.predictions.bias",
);
assert_eq!(
hf_canonical_save_key("lm_head.decoder.bias"),
"lm_head.bias",
);
}
// All other keys pass through the save-key mapping untouched.
#[test]
fn hf_canonical_save_key_passthrough_for_unrelated_keys() {
for k in [
"bert.embeddings.word_embeddings.weight",
"bert.embeddings.LayerNorm.weight",
"bert.encoder.layer.0.attention.self.query.bias",
] {
assert_eq!(hf_canonical_save_key(k), k);
}
}
// Exact key/shape agreement validates cleanly.
#[test]
fn all_keys_match_returns_ok() {
let expected = vec![
ExpectedParam { key: "bert.embeddings.word_embeddings.weight".into(), shape: vec![30522, 768] },
ExpectedParam { key: "bert.pooler.dense.bias".into(), shape: vec![768] },
];
let actual = actual_map(&[
("bert.embeddings.word_embeddings.weight", &[30522, 768]),
("bert.pooler.dense.bias", &[768]),
]);
let v = validate_keys(&expected, &actual);
assert!(v.is_ok());
assert!(v.into_result().is_ok());
}
// A key the model expects but the checkpoint lacks lands in `missing`.
#[test]
fn missing_key_is_reported() {
let expected = vec![
ExpectedParam { key: "bert.pooler.dense.weight".into(), shape: vec![768, 768] },
];
let actual = actual_map(&[]);
let v = validate_keys(&expected, &actual);
assert_eq!(v.missing, vec!["bert.pooler.dense.weight"]);
assert!(v.unused.is_empty());
assert!(v.shape_mismatches.is_empty());
}
// A key the checkpoint has but the model lacks lands in `unused`.
#[test]
fn unused_checkpoint_key_is_reported() {
let expected: Vec<ExpectedParam> = Vec::new();
let actual = actual_map(&[("bert.something.extra", &[4])]);
let v = validate_keys(&expected, &actual);
assert_eq!(v.unused, vec!["bert.something.extra"]);
assert!(v.missing.is_empty());
assert!(v.shape_mismatches.is_empty());
}
// Same key, different shape is reported with both shapes.
#[test]
fn shape_mismatch_is_reported() {
let expected = vec![
ExpectedParam {
key: "bert.embeddings.word_embeddings.weight".into(),
shape: vec![50257, 768],
},
];
let actual = actual_map(&[
("bert.embeddings.word_embeddings.weight", &[30522, 768]),
]);
let v = validate_keys(&expected, &actual);
assert!(v.missing.is_empty());
assert!(v.unused.is_empty());
assert_eq!(v.shape_mismatches.len(), 1);
assert_eq!(v.shape_mismatches[0].key, "bert.embeddings.word_embeddings.weight");
assert_eq!(v.shape_mismatches[0].expected, vec![50257, 768]);
assert_eq!(v.shape_mismatches[0].found, vec![30522, 768]);
}
// A typo in a model key shows up as one missing plus one unused entry,
// which together point straight at the misspelling.
#[test]
fn typo_queri_vs_query_reports_both_missing_and_unused() {
let expected = vec![
ExpectedParam { key: "bert.encoder.layer.0.attention.self.queri.weight".into(), shape: vec![768, 768] },
];
let actual = actual_map(&[
("bert.encoder.layer.0.attention.self.query.weight", &[768, 768]),
]);
let v = validate_keys(&expected, &actual);
assert_eq!(v.missing, vec!["bert.encoder.layer.0.attention.self.queri.weight"]);
assert_eq!(v.unused, vec!["bert.encoder.layer.0.attention.self.query.weight"]);
}
// Several failure kinds at once all accumulate in one report.
#[test]
fn mixed_failures_accumulate() {
let expected = vec![
ExpectedParam { key: "ok.weight".into(), shape: vec![4] },
ExpectedParam { key: "missing.weight".into(), shape: vec![8] },
ExpectedParam { key: "wrong_shape.weight".into(), shape: vec![16] },
];
let actual = actual_map(&[
("ok.weight", &[4]),
("wrong_shape.weight", &[32]),
("extra.weight", &[1]),
]);
let v = validate_keys(&expected, &actual);
assert_eq!(v.missing, vec!["missing.weight"]);
assert_eq!(v.unused, vec!["extra.weight"]);
assert_eq!(v.shape_mismatches.len(), 1);
assert_eq!(v.shape_mismatches[0].key, "wrong_shape.weight");
}
// The composed error message names every bucket and every offending key.
#[test]
fn into_result_error_message_lists_every_bucket() {
let expected = vec![
ExpectedParam { key: "m.w".into(), shape: vec![2] },
ExpectedParam { key: "sm.w".into(), shape: vec![3] },
];
let actual = actual_map(&[
("sm.w", &[4]),
("extra.w", &[1]),
]);
let v = validate_keys(&expected, &actual);
let err = v.into_result().unwrap_err().to_string();
assert!(err.contains("1 missing key"), "missing bucket not in msg: {err}");
assert!(err.contains("1 unused key"), "unused bucket not in msg: {err}");
assert!(err.contains("1 shape mismatch"), "shape bucket not in msg: {err}");
assert!(err.contains("m.w"));
assert!(err.contains("extra.w"));
assert!(err.contains("sm.w"));
assert!(err.contains("[3]"));
assert!(err.contains("[4]"));
}
// Buckets are sorted so repeated runs produce identical messages.
#[test]
fn output_is_sorted_for_stable_messages() {
let expected = vec![
ExpectedParam { key: "z.w".into(), shape: vec![1] },
ExpectedParam { key: "a.w".into(), shape: vec![1] },
];
let actual = actual_map(&[
("m.w", &[1]),
("c.w", &[1]),
]);
let v = validate_keys(&expected, &actual);
assert_eq!(v.missing, vec!["a.w", "z.w"]);
assert_eq!(v.unused, vec!["c.w", "m.w"]);
}
// Empty model vs empty checkpoint is a valid (trivial) match.
#[test]
fn empty_everywhere_is_ok() {
let v = validate_keys(&[], &HashMap::new());
assert!(v.is_ok());
assert!(v.missing.is_empty());
assert!(v.unused.is_empty());
assert!(v.shape_mismatches.is_empty());
assert!(v.into_result().is_ok());
}
// Long key lists are truncated to 20 entries with a "... and N more" tail,
// while the header still shows the full count.
#[test]
fn into_result_truncates_long_missing_list() {
let expected: Vec<ExpectedParam> = (0..25)
.map(|i| ExpectedParam { key: format!("key.{i:02}"), shape: vec![1] })
.collect();
let v = validate_keys(&expected, &HashMap::new());
assert_eq!(v.missing.len(), 25);
let err = v.into_result().unwrap_err().to_string();
assert!(err.contains("25 missing key"), "header must show full count: {err}");
assert!(err.contains("... and 5 more"),
"truncation tail must show remaining count: {err}");
assert!(err.contains("key.00"));
assert!(err.contains("key.19"));
assert!(!err.contains("key.20"));
}
// Pack f32 values into their little-endian byte representation.
fn f32_le_bytes(data: &[f32]) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len() * 4);
    for &value in data {
        out.extend_from_slice(&value.to_le_bytes());
    }
    out
}
// Build a real safetensors blob from (name, dtype, shape, bytes) tuples.
fn serialize_entries(entries: &[(&str, Dtype, Vec<usize>, Vec<u8>)]) -> Vec<u8> {
    let mut views: HashMap<String, TensorView<'_>> = HashMap::new();
    for (name, dtype, shape, bytes) in entries {
        let view = TensorView::new(*dtype, shape.clone(), bytes).unwrap();
        views.insert((*name).to_string(), view);
    }
    safetensors::serialize(&views, &None).unwrap()
}
// End-to-end: a known F32 checkpoint loads into a fresh graph and the
// parameter values match what was serialised.
#[test]
fn load_safetensors_f32_roundtrip() {
use flodl::{FlowBuilder, Linear, Module, Variable};
let in_dim = 3_i64;
let out_dim = 2_i64;
let dev = Device::CPU;
let src_graph = FlowBuilder::new()
.through(Linear::on_device(in_dim, out_dim, dev).unwrap())
.tag("my.linear")
.build().unwrap();
// Deterministic source values, so equality below is exact.
let src_weight: Vec<f32> = (0..(in_dim * out_dim) as usize)
.map(|i| 1.0 + i as f32 * 0.25).collect();
let src_bias: Vec<f32> = (0..out_dim as usize)
.map(|i| -0.5 + i as f32).collect();
for (k, p) in src_graph.named_parameters() {
let hf = hf_key_from_flodl_key(&k);
let t = match hf.as_str() {
"my.linear.weight" => Tensor::from_f32(&src_weight, &[out_dim, in_dim], dev).unwrap(),
"my.linear.bias" => Tensor::from_f32(&src_bias, &[out_dim], dev).unwrap(),
other => panic!("unexpected key {other}"),
};
p.variable.set_data(t);
}
let w_bytes = f32_le_bytes(&src_weight);
let b_bytes = f32_le_bytes(&src_bias);
let bytes = serialize_entries(&[
("my.linear.weight", Dtype::F32, vec![out_dim as usize, in_dim as usize], w_bytes),
("my.linear.bias", Dtype::F32, vec![out_dim as usize], b_bytes),
]);
// Load into a second, freshly-initialised graph.
let dst_graph = FlowBuilder::new()
.through(Linear::on_device(in_dim, out_dim, dev).unwrap())
.tag("my.linear")
.build().unwrap();
load_safetensors_into_graph(&dst_graph, &bytes).unwrap();
let mut dst_weight: Option<Vec<f32>> = None;
let mut dst_bias: Option<Vec<f32>> = None;
for (k, p) in dst_graph.named_parameters() {
let hf = hf_key_from_flodl_key(&k);
let data = p.variable.data().to_f32_vec().unwrap();
match hf.as_str() {
"my.linear.weight" => dst_weight = Some(data),
"my.linear.bias" => dst_bias = Some(data),
other => panic!("unexpected key {other}"),
}
}
assert_eq!(dst_weight.unwrap(), src_weight);
assert_eq!(dst_bias.unwrap(), src_bias);
let _keep_alive: Vec<Variable> = dst_graph.parameters().into_iter().map(|p| p.variable).collect();
}
// The file-based loader reads a blob written to a temp file and restores
// the same parameter values.
#[test]
fn load_safetensors_file_roundtrip() {
use flodl::{FlowBuilder, Linear};
use std::io::Write;
let dev = Device::CPU;
let graph = FlowBuilder::new()
.through(Linear::on_device(2, 1, dev).unwrap())
.tag("m")
.build().unwrap();
let w = vec![0.25_f32, 0.5];
let b = vec![1.5_f32];
for (k, p) in graph.named_parameters() {
let hf = hf_key_from_flodl_key(&k);
let t = match hf.as_str() {
"m.weight" => Tensor::from_f32(&w, &[1, 2], dev).unwrap(),
"m.bias" => Tensor::from_f32(&b, &[1], dev).unwrap(),
other => panic!("unexpected {other}"),
};
p.variable.set_data(t);
}
let bytes = serialize_entries(&[
("m.weight", Dtype::F32, vec![1, 2], f32_le_bytes(&w)),
("m.bias", Dtype::F32, vec![1], f32_le_bytes(&b)),
]);
// Process-id suffix keeps parallel test runs from clobbering each other.
let path = std::env::temp_dir().join(format!("flodl_hf_test_{}.safetensors", std::process::id()));
std::fs::File::create(&path).unwrap().write_all(&bytes).unwrap();
let fresh = FlowBuilder::new()
.through(Linear::on_device(2, 1, dev).unwrap())
.tag("m")
.build().unwrap();
load_safetensors_file_into_graph(&fresh, &path).unwrap();
let _ = std::fs::remove_file(&path);
for (k, p) in fresh.named_parameters() {
let hf = hf_key_from_flodl_key(&k);
let data = p.variable.data().to_f32_vec().unwrap();
match hf.as_str() {
"m.weight" => assert_eq!(data, w),
"m.bias" => assert_eq!(data, b),
other => panic!("unexpected {other}"),
}
}
}
// Loading BF16 storage must keep the tensor BF16 (no silent upcast);
// the chosen values are exactly representable so equality is exact.
#[test]
fn load_safetensors_bf16_preserves_dtype() {
use flodl::{DType, FlowBuilder, Linear};
let dev = Device::CPU;
let graph = FlowBuilder::new()
.through(Linear::on_device(2, 2, dev).unwrap())
.tag("m")
.build().unwrap();
let exact_w = [1.0_f32, 2.0, -0.5, 0.25];
let exact_b = [0.0_f32, -1.0];
// bf16 encoding = top 16 bits of the f32 representation.
let to_bf16_bytes = |data: &[f32]| -> Vec<u8> {
let mut out = Vec::with_capacity(data.len() * 2);
for &f in data {
let top = (f.to_bits() >> 16) as u16;
out.extend_from_slice(&top.to_le_bytes());
}
out
};
let bytes = serialize_entries(&[
("m.weight", Dtype::BF16, vec![2, 2], to_bf16_bytes(&exact_w)),
("m.bias", Dtype::BF16, vec![2], to_bf16_bytes(&exact_b)),
]);
load_safetensors_into_graph(&graph, &bytes).unwrap();
for (k, p) in graph.named_parameters() {
let hf = hf_key_from_flodl_key(&k);
assert_eq!(p.variable.data().dtype(), DType::BFloat16,
"{hf}: dtype must be preserved as BF16 after load");
let data = p.variable.data().to_f32_vec().unwrap();
match hf.as_str() {
"m.weight" => assert_eq!(data, exact_w),
"m.bias" => assert_eq!(data, exact_b),
other => panic!("unexpected {other}"),
}
}
}
// Same dtype-preservation guarantee for IEEE half (F16) storage.
#[test]
fn load_safetensors_f16_preserves_dtype() {
use flodl::{DType, FlowBuilder, Linear};
let dev = Device::CPU;
let graph = FlowBuilder::new()
.through(Linear::on_device(1, 4, dev).unwrap())
.tag("m")
.build().unwrap();
// Bit patterns for 1.0, -1.0, 0.5, 0.0 in binary16.
let f16_bits: [u16; 4] = [0x3C00, 0xBC00, 0x3800, 0x0000];
let mut bytes_w = Vec::with_capacity(8);
for b in f16_bits {
bytes_w.extend_from_slice(&b.to_le_bytes());
}
let bias_bits: [u16; 1] = [0x3C00];
let bytes_b: Vec<u8> = bias_bits[0].to_le_bytes().to_vec();
let st_bytes = serialize_entries(&[
("m.weight", Dtype::F16, vec![4, 1], bytes_w),
("m.bias", Dtype::F16, vec![4], bytes_b.repeat(4)),
]);
load_safetensors_into_graph(&graph, &st_bytes).unwrap();
for (k, p) in graph.named_parameters() {
let hf = hf_key_from_flodl_key(&k);
assert_eq!(p.variable.data().dtype(), DType::Float16,
"{hf}: dtype must be preserved as F16 after load");
let data = p.variable.data().to_f32_vec().unwrap();
match hf.as_str() {
"m.weight" => assert_eq!(data, vec![1.0, -1.0, 0.5, 0.0]),
"m.bias" => assert_eq!(data, vec![1.0, 1.0, 1.0, 1.0]),
other => panic!("unexpected {other}"),
}
}
}
// Load-then-save of F16 data must reproduce the original bytes exactly
// (no round-trip through f32 allowed to perturb the bits).
#[test]
fn save_safetensors_f16_roundtrip_byte_exact() {
use flodl::{FlowBuilder, Linear};
let dev = Device::CPU;
let graph = FlowBuilder::new()
.through(Linear::on_device(1, 4, dev).unwrap())
.tag("m")
.build().unwrap();
let f16_bits: [u16; 4] = [0x3C00, 0xBC00, 0x3800, 0x0000];
let bytes_w: Vec<u8> = f16_bits.iter().flat_map(|b| b.to_le_bytes()).collect();
let bytes_b: Vec<u8> = (0..4).flat_map(|_| 0x3C00u16.to_le_bytes()).collect();
let src = serialize_entries(&[
("m.weight", Dtype::F16, vec![4, 1], bytes_w.clone()),
("m.bias", Dtype::F16, vec![4], bytes_b.clone()),
]);
load_safetensors_into_graph(&graph, &src).unwrap();
let saved = save_safetensors_from_graph(&graph).unwrap();
let saved_st = SafeTensors::deserialize(&saved).unwrap();
for (k, expected_bytes) in [("m.weight", &bytes_w), ("m.bias", &bytes_b)] {
let v = saved_st.tensor(k).unwrap();
assert_eq!(v.dtype(), Dtype::F16, "{k}: must save back as F16");
assert_eq!(v.data(), expected_bytes.as_slice(),
"{k}: F16 bytes must be bit-exact through load+save");
}
}
// Same bit-exactness guarantee for BF16 storage.
#[test]
fn save_safetensors_bf16_roundtrip_byte_exact() {
use flodl::{FlowBuilder, Linear};
let dev = Device::CPU;
let graph = FlowBuilder::new()
.through(Linear::on_device(2, 2, dev).unwrap())
.tag("m")
.build().unwrap();
let exact_w = [1.0_f32, 2.0, -0.5, 0.25];
let exact_b = [0.0_f32, -1.0];
let to_bf16_bytes = |data: &[f32]| -> Vec<u8> {
data.iter().flat_map(|f| ((f.to_bits() >> 16) as u16).to_le_bytes()).collect()
};
let bytes_w = to_bf16_bytes(&exact_w);
let bytes_b = to_bf16_bytes(&exact_b);
let src = serialize_entries(&[
("m.weight", Dtype::BF16, vec![2, 2], bytes_w.clone()),
("m.bias", Dtype::BF16, vec![2], bytes_b.clone()),
]);
load_safetensors_into_graph(&graph, &src).unwrap();
let saved = save_safetensors_from_graph(&graph).unwrap();
let saved_st = SafeTensors::deserialize(&saved).unwrap();
for (k, expected_bytes) in [("m.weight", &bytes_w), ("m.bias", &bytes_b)] {
let v = saved_st.tensor(k).unwrap();
assert_eq!(v.dtype(), Dtype::BF16, "{k}: must save back as BF16");
assert_eq!(v.data(), expected_bytes.as_slice(),
"{k}: BF16 bytes must be bit-exact through load+save");
}
}
// A checkpoint lacking a parameter the model needs must fail with a
// message that names the missing key.
#[test]
fn load_safetensors_missing_key_errors_loudly() {
use flodl::{FlowBuilder, Linear};
let dev = Device::CPU;
let graph = FlowBuilder::new()
.through(Linear::on_device(2, 2, dev).unwrap())
.tag("m")
.build().unwrap();
let w = vec![0.0_f32, 1.0, 2.0, 3.0];
let bytes = serialize_entries(&[
("m.weight", Dtype::F32, vec![2, 2], f32_le_bytes(&w)),
]);
let err = load_safetensors_into_graph(&graph, &bytes).unwrap_err().to_string();
assert!(err.contains("missing key"), "error must mention missing keys: {err}");
assert!(err.contains("m.bias"), "error must name the missing key: {err}");
}
// Integer-typed tensors are rejected with the offending dtype named.
#[test]
fn load_safetensors_rejects_integer_dtype() {
use flodl::{FlowBuilder, Linear};
let dev = Device::CPU;
let graph = FlowBuilder::new()
.through(Linear::on_device(1, 1, dev).unwrap())
.tag("m")
.build().unwrap();
let bias_i32: Vec<u8> = 1_i32.to_le_bytes().to_vec();
let w_bytes = f32_le_bytes(&[0.5_f32]);
let bytes = serialize_entries(&[
("m.weight", Dtype::F32, vec![1, 1], w_bytes),
("m.bias", Dtype::I32, vec![1], bias_i32),
]);
let err = load_safetensors_into_graph(&graph, &bytes).unwrap_err().to_string();
assert!(err.contains("unsupported safetensors dtype"),
"error must call out dtype: {err}");
assert!(err.contains("I32"), "error must name the offending dtype: {err}");
}
// The combined legacy rename rewrites LayerNorm gamma/beta suffixes at any
// depth and leaves unrelated keys (including a bare ".gamma") alone.
#[test]
fn bert_legacy_key_rename_rewrites_layernorm_suffixes() {
assert_eq!(
bert_legacy_key_rename("bert.embeddings.LayerNorm.gamma"),
"bert.embeddings.LayerNorm.weight",
);
assert_eq!(
bert_legacy_key_rename("bert.embeddings.LayerNorm.beta"),
"bert.embeddings.LayerNorm.bias",
);
assert_eq!(
bert_legacy_key_rename("bert.encoder.layer.3.attention.output.LayerNorm.gamma"),
"bert.encoder.layer.3.attention.output.LayerNorm.weight",
);
assert_eq!(
bert_legacy_key_rename("bert.encoder.layer.0.attention.self.query.weight"),
"bert.encoder.layer.0.attention.self.query.weight",
);
assert_eq!(
bert_legacy_key_rename("something.gamma"),
"something.gamma",
);
}
// The MLM tied-bias aliases are exact-match only — a prefixed variant is
// deliberately left untouched.
#[test]
fn bert_legacy_key_rename_retags_mlm_tied_bias() {
assert_eq!(
bert_legacy_key_rename("cls.predictions.bias"),
"cls.predictions.decoder.bias",
);
assert_eq!(
bert_legacy_key_rename("lm_head.bias"),
"lm_head.decoder.bias",
);
assert_eq!(
bert_legacy_key_rename("something.cls.predictions.bias"),
"something.cls.predictions.bias",
);
}
// Two checkpoint keys renaming to the same canonical key must be refused,
// naming the collision, rather than silently picking a winner.
#[test]
fn load_safetensors_rename_collision_errors_loudly() {
use flodl::{FlowBuilder, Linear};
let dev = Device::CPU;
let graph = FlowBuilder::new()
.through(Linear::on_device(2, 2, dev).unwrap())
.tag("m")
.build().unwrap();
let w = f32_le_bytes(&[0.0, 1.0, 2.0, 3.0]);
let b = f32_le_bytes(&[0.1, 0.2]);
let bytes = serialize_entries(&[
("m.weight", Dtype::F32, vec![2, 2], w.clone()),
("m.LayerNorm.gamma", Dtype::F32, vec![2], b.clone()),
("m.LayerNorm.weight", Dtype::F32, vec![2], b),
]);
let err = load_safetensors_into_graph_with_rename(
&graph, &bytes, bert_legacy_key_rename,
).unwrap_err().to_string();
assert!(err.contains("rename collision"),
"error must identify the collision: {err}");
assert!(err.contains("LayerNorm.weight"),
"error must name the canonical key involved: {err}");
}
#[test]
fn f16_bits_to_f32_special_values() {
    // Positive and negative infinity (exponent all-ones, zero mantissa).
    let pos_inf = f16_bits_to_f32(0x7C00);
    assert!(pos_inf.is_infinite() && pos_inf.is_sign_positive());
    let neg_inf = f16_bits_to_f32(0xFC00);
    assert!(neg_inf.is_infinite() && neg_inf.is_sign_negative());
    // Quiet NaN.
    assert!(f16_bits_to_f32(0x7E00).is_nan());
    // Smallest positive subnormal: 2^-24.
    let tiny = f16_bits_to_f32(0x0001);
    assert!(
        (tiny - 2.0_f32.powi(-24)).abs() < 1e-10,
        "tiny subnormal wrong: {tiny}"
    );
    // Negative zero: compares equal to 0.0 yet keeps its sign bit.
    let neg_zero = f16_bits_to_f32(0x8000);
    assert!(neg_zero.is_sign_negative() && neg_zero == 0.0);
}
#[test]
fn expected_from_graph_converts_slash_to_dot() {
    use flodl::{FlowBuilder, Linear, Module};
    let graph = FlowBuilder::new()
        .through(Linear::new(4, 2).unwrap())
        .tag("bert.pooler.dense")
        .build()
        .unwrap();
    let expected = expected_from_graph(&graph);
    let keys: Vec<&str> = expected.iter().map(|e| e.key.as_str()).collect();
    // Both parameters of the tagged linear layer must show up under HF dotted keys.
    for want in ["bert.pooler.dense.weight", "bert.pooler.dense.bias"] {
        assert!(
            keys.contains(&want),
            "expected HF-dotted key missing, got {keys:?}"
        );
    }
    // One expected entry per graph parameter — nothing dropped, nothing invented.
    assert_eq!(expected.len(), graph.parameters().len());
}
#[test]
fn save_safetensors_load_roundtrip() {
    use flodl::{FlowBuilder, Linear, Module, Variable};
    let dev = Device::CPU;
    let in_dim = 3_i64;
    let out_dim = 2_i64;
    // Source model whose parameters we overwrite with known values.
    let src = FlowBuilder::new()
        .through(Linear::on_device(in_dim, out_dim, dev).unwrap())
        .tag("my.linear")
        .build()
        .unwrap();
    let src_weight: Vec<f32> = (0..(in_dim * out_dim) as usize)
        .map(|i| 0.5 + i as f32 * 0.1)
        .collect();
    let src_bias: Vec<f32> = (0..out_dim as usize)
        .map(|i| -1.0 + i as f32 * 0.25)
        .collect();
    for (key, param) in src.named_parameters() {
        let hf = hf_key_from_flodl_key(&key);
        let tensor = match hf.as_str() {
            "my.linear.weight" => {
                Tensor::from_f32(&src_weight, &[out_dim, in_dim], dev).unwrap()
            }
            "my.linear.bias" => Tensor::from_f32(&src_bias, &[out_dim], dev).unwrap(),
            other => panic!("unexpected key {other}"),
        };
        param.variable.set_data(tensor);
    }
    // Serialize, then deserialize into a fresh model of identical topology.
    let bytes = save_safetensors_from_graph(&src).unwrap();
    let dst = FlowBuilder::new()
        .through(Linear::on_device(in_dim, out_dim, dev).unwrap())
        .tag("my.linear")
        .build()
        .unwrap();
    load_safetensors_into_graph(&dst, &bytes).unwrap();
    // Read back both parameters from the destination model.
    let mut loaded_weight: Option<Vec<f32>> = None;
    let mut loaded_bias: Option<Vec<f32>> = None;
    for (key, param) in dst.named_parameters() {
        let data = param.variable.data().to_f32_vec().unwrap();
        match hf_key_from_flodl_key(&key).as_str() {
            "my.linear.weight" => loaded_weight = Some(data),
            "my.linear.bias" => loaded_bias = Some(data),
            other => panic!("unexpected key {other}"),
        }
    }
    assert_eq!(loaded_weight.unwrap(), src_weight);
    assert_eq!(loaded_bias.unwrap(), src_bias);
    let _keep_alive: Vec<Variable> = dst
        .parameters()
        .into_iter()
        .map(|p| p.variable)
        .collect();
}
#[test]
fn save_safetensors_uses_hf_dotted_keys_and_le_f32() {
    use flodl::{FlowBuilder, Linear};
    let dev = Device::CPU;
    let graph = FlowBuilder::new()
        .through(Linear::on_device(2, 1, dev).unwrap())
        .tag("encoder.layer.0.attention.output.dense")
        .build()
        .unwrap();
    let weight = vec![0.25_f32, -0.5];
    let bias = vec![1.0_f32];
    for (key, param) in graph.named_parameters() {
        let tensor = match hf_key_from_flodl_key(&key).as_str() {
            "encoder.layer.0.attention.output.dense.weight" => {
                Tensor::from_f32(&weight, &[1, 2], dev).unwrap()
            }
            "encoder.layer.0.attention.output.dense.bias" => {
                Tensor::from_f32(&bias, &[1], dev).unwrap()
            }
            other => panic!("unexpected key {other}"),
        };
        param.variable.set_data(tensor);
    }
    let bytes = save_safetensors_from_graph(&graph).unwrap();
    let st = SafeTensors::deserialize(&bytes).unwrap();
    let names: HashSet<&str> = st.names().iter().map(|s| s.as_str()).collect();
    // Output keys must be HF dotted paths.
    for want in [
        "encoder.layer.0.attention.output.dense.weight",
        "encoder.layer.0.attention.output.dense.bias",
    ] {
        assert!(
            names.contains(want),
            "expected HF-dotted key in output, got {names:?}"
        );
    }
    // Payload must be little-endian F32 with the original shape; decode and
    // compare each parameter against what we stored.
    let decode_f32 = |view: &TensorView| -> Vec<f32> {
        view.data()
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
            .collect()
    };
    let w_view = st
        .tensor("encoder.layer.0.attention.output.dense.weight")
        .unwrap();
    assert_eq!(w_view.dtype(), Dtype::F32);
    assert_eq!(w_view.shape(), &[1_usize, 2]);
    assert_eq!(decode_f32(&w_view), weight);
    let b_view = st
        .tensor("encoder.layer.0.attention.output.dense.bias")
        .unwrap();
    assert_eq!(decode_f32(&b_view), bias);
}
#[test]
fn save_safetensors_dedups_shared_weights() {
    use flodl::{FlowBuilder, Linear, Parameter};
    let dev = Device::CPU;
    // Two linear layers sharing one weight tensor; each keeps its own bias.
    let primary = Linear::on_device(2, 2, dev).unwrap();
    let shared_weight = primary.weight.clone();
    let tied_bias = Parameter::new(
        Tensor::from_f32(&[0.0_f32, 0.0], &[2], dev).unwrap(),
        "bias",
    );
    let tied = Linear::from_shared_weight(shared_weight, Some(tied_bias));
    let graph = FlowBuilder::new()
        .through(primary)
        .tag("primary")
        .through(tied)
        .tag("tied")
        .build()
        .unwrap();
    let bytes = save_safetensors_from_graph(&graph).unwrap();
    let st = SafeTensors::deserialize(&bytes).unwrap();
    let names: HashSet<&str> = st.names().iter().map(|s| s.as_str()).collect();
    // The shared weight must be serialized exactly once, under either tag.
    let weight_count = ["primary.weight", "tied.weight"]
        .iter()
        .filter(|k| names.contains(*k))
        .count();
    assert_eq!(
        weight_count, 1,
        "shared weight must ship exactly once, got {names:?}",
    );
    // Biases are not shared, so both must be present.
    assert!(names.contains("primary.bias"), "primary bias missing in {names:?}");
    assert!(names.contains("tied.bias"), "tied bias missing in {names:?}");
}
#[test]
fn save_safetensors_file_roundtrip() {
    use flodl::{FlowBuilder, Linear};
    let dev = Device::CPU;
    let graph = FlowBuilder::new()
        .through(Linear::on_device(2, 1, dev).unwrap())
        .tag("m")
        .build()
        .unwrap();
    let w = vec![0.1_f32, 0.2];
    let b = vec![0.3_f32];
    for (k, p) in graph.named_parameters() {
        let hf = hf_key_from_flodl_key(&k);
        let t = match hf.as_str() {
            "m.weight" => Tensor::from_f32(&w, &[1, 2], dev).unwrap(),
            "m.bias" => Tensor::from_f32(&b, &[1], dev).unwrap(),
            other => panic!("unexpected {other}"),
        };
        p.variable.set_data(t);
    }
    // Unique temp path: pid alone can collide with a stale file left by a crashed
    // earlier run whose pid was recycled, so add a nanosecond timestamp.
    let path = std::env::temp_dir().join(format!(
        "flodl_hf_save_test_{}_{}.safetensors",
        std::process::id(),
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos(),
    ));
    save_safetensors_file_from_graph(&graph, &path).unwrap();
    let fresh = FlowBuilder::new()
        .through(Linear::on_device(2, 1, dev).unwrap())
        .tag("m")
        .build()
        .unwrap();
    // Capture the load result BEFORE unwrapping so the temp file is removed even
    // when the load fails; previously a failed load leaked the file.
    let load_res = load_safetensors_file_into_graph(&fresh, &path);
    let _ = std::fs::remove_file(&path);
    load_res.unwrap();
    for (k, p) in fresh.named_parameters() {
        let hf = hf_key_from_flodl_key(&k);
        let data = p.variable.data().to_f32_vec().unwrap();
        match hf.as_str() {
            "m.weight" => assert_eq!(data, w),
            "m.bias" => assert_eq!(data, b),
            other => panic!("unexpected {other}"),
        }
    }
}
}