use super::{LoraAdapter, LoraConfig, LoraLayer};
use crate::error::TuneError;
use safetensors::Dtype;
use safetensors::tensor::{SafeTensors, TensorView, serialize};
use std::collections::HashMap;
use std::path::Path;
/// Result of decoding one safetensors tensor name with `parse_peft_key`:
/// which layer, which module, and which of the two LoRA matrices it names.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PeftKey {
/// Layer number parsed from the `.layers.<N>.` segment of the key.
layer_idx: usize,
/// Last dot-separated path component before the LoRA suffix (e.g. `q_proj`).
module: String,
/// Whether the key names the A or the B matrix.
matrix: LoraMatrix,
/// Set for keys using the lowercase `lora_a` / `lora_b` suffix, which the
/// parser treats as MLX-style; the loader transposes such 2-D tensors.
transposed: bool,
}
/// Which of the two low-rank factors a tensor key refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LoraMatrix {
/// Matched by the `.lora_A` / `.lora_a` key suffix; the loader reads its
/// shape as `[rank, d_in]`.
A,
/// Matched by the `.lora_B` / `.lora_b` key suffix; the loader reads its
/// shape as `[d_out, rank]`.
B,
}
fn parse_peft_key(key: &str) -> Option<PeftKey> {
let key = key.strip_suffix(".weight").unwrap_or(key);
let (key, matrix, is_mlx) = if let Some(k) = key.strip_suffix(".lora_A") {
(k, LoraMatrix::A, false)
} else if let Some(k) = key.strip_suffix(".lora_B") {
(k, LoraMatrix::B, false)
} else if let Some(k) = key.strip_suffix(".lora_a") {
(k, LoraMatrix::A, true)
} else if let Some(k) = key.strip_suffix(".lora_b") {
(k, LoraMatrix::B, true)
} else {
return None;
};
let layers_marker = ".layers.";
let layers_pos = key.find(layers_marker)?;
let after_layers = &key[layers_pos + layers_marker.len()..];
let dot_pos = after_layers.find('.')?;
let layer_idx: usize = after_layers[..dot_pos].parse().ok()?;
let rest = &after_layers[dot_pos + 1..];
let module = rest.rsplit('.').next()?.to_string();
if module.is_empty() {
return None;
}
Some(PeftKey {
layer_idx,
module,
matrix,
transposed: is_mlx,
})
}
/// Reads the tensor called `name` and returns its elements widened to `f32`
/// together with its shape.
///
/// F32, F16 and BF16 payloads are decoded from little-endian bytes.
///
/// # Errors
///
/// Returns `TuneError::Serialization` when the tensor cannot be looked up,
/// its byte length is not a multiple of the element width, its dtype is not
/// one of F32/F16/BF16, or any decoded element is NaN or infinite.
fn read_tensor_f32(
    tensors: &SafeTensors<'_>,
    name: &str,
) -> Result<(Vec<f32>, Vec<usize>), TuneError> {
    let view = tensors
        .tensor(name)
        .map_err(|e| TuneError::Serialization(format!("failed to read tensor '{name}': {e}")))?;
    let dims: Vec<usize> = view.shape().to_vec();
    let raw = view.data();
    let byte_len = raw.len();

    // Each dtype has a guard arm that rejects a ragged byte length before
    // the decoding arm runs.
    let decoded: Vec<f32> = match view.dtype() {
        Dtype::F32 if byte_len % 4 != 0 => {
            return Err(TuneError::Serialization(format!(
                "tensor '{name}' f32 data length {} not aligned to 4 bytes",
                byte_len
            )));
        }
        Dtype::F32 => raw
            .chunks_exact(4)
            .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
            .collect(),
        Dtype::F16 if byte_len % 2 != 0 => {
            return Err(TuneError::Serialization(format!(
                "tensor '{name}' f16 data length {} not aligned to 2 bytes",
                byte_len
            )));
        }
        Dtype::F16 => raw
            .chunks_exact(2)
            .map(|half| f16_to_f32(u16::from_le_bytes([half[0], half[1]])))
            .collect(),
        Dtype::BF16 if byte_len % 2 != 0 => {
            return Err(TuneError::Serialization(format!(
                "tensor '{name}' bf16 data length {} not aligned to 2 bytes",
                byte_len
            )));
        }
        Dtype::BF16 => raw
            .chunks_exact(2)
            .map(|half| bf16_to_f32(u16::from_le_bytes([half[0], half[1]])))
            .collect(),
        other => {
            return Err(TuneError::Serialization(format!(
                "tensor '{name}' has unsupported dtype {other:?}, expected F32/F16/BF16"
            )));
        }
    };

    // Report the first NaN / infinity, if any.
    if let Some(idx) = decoded.iter().position(|v| !v.is_finite()) {
        let value = decoded[idx];
        return Err(TuneError::Serialization(format!(
            "tensor '{name}' contains non-finite value at index {idx}: {value}"
        )));
    }

    Ok((decoded, dims))
}
/// Widens an IEEE 754 half-precision value, given as its raw 16 bits, to `f32`.
///
/// Handles signed zero, subnormals, normal numbers, infinities and NaN
/// (the NaN payload is shifted into the upper bits of the `f32` mantissa).
fn f16_to_f32(bits: u16) -> f32 {
    let half = u32::from(bits);
    let sign_bit = (half >> 15) << 31;
    let exponent = (half >> 10) & 0x1f;
    let mantissa = half & 0x3ff;

    match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => f32::from_bits(sign_bit),
        // Subnormal: mantissa / 2^10 scaled by the minimum exponent 2^-14.
        (0, _) => {
            let magnitude = (mantissa as f32) / 1024.0 * (2.0f32).powi(-14);
            if sign_bit != 0 { -magnitude } else { magnitude }
        }
        // Infinity (mantissa == 0) or NaN (mantissa != 0).
        (0x1f, _) => f32::from_bits(sign_bit | (0xff << 23) | (mantissa << 13)),
        // Normal number: re-bias the exponent from 15 to 127.
        _ => f32::from_bits(sign_bit | ((exponent + 127 - 15) << 23) | (mantissa << 13)),
    }
}
/// Widens a bfloat16 value, given as its raw 16 bits, to `f32` by placing
/// those bits in the upper half of the 32-bit pattern.
fn bf16_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits) << 16;
    f32::from_bits(widened)
}
/// Loads a LoRA adapter from a safetensors file at `path`.
///
/// Every tensor whose name `parse_peft_key` recognises is decoded to `f32`
/// and grouped by `(layer index, module name)`; other tensors are ignored.
/// Tensors flagged as transposed (lowercase `lora_a` / `lora_b` keys) are
/// transposed when 2-D, so that after loading A is `[rank, d_in]` and B is
/// `[d_out, rank]` for every layer.
///
/// The returned config uses the common rank of all layers, sets `alpha`
/// equal to that rank (this function does not read alpha from the file),
/// and lists the distinct module names in sorted order.
///
/// # Errors
///
/// * `TuneError::Io` if the file metadata or contents cannot be read, or
///   the file is larger than 10 GiB.
/// * `TuneError::Serialization` if the file is not valid safetensors, a
///   tensor cannot be decoded, no LoRA tensors are present, two tensor
///   names map to the same layer/module/matrix, an A or B matrix has no
///   partner, a matrix is not 2-D, ranks disagree within or across layers,
///   or a shape does not match the amount of data.
pub fn load_peft_safetensors(path: &Path) -> Result<LoraAdapter, TuneError> {
    // Refuse to slurp files above 10 GiB into memory.
    const MAX_LORA_SIZE: u64 = 10 * 1024 * 1024 * 1024;

    let file_size = std::fs::metadata(path)
        .map_err(|e| {
            TuneError::Io(std::io::Error::new(
                e.kind(),
                format!("failed to read metadata for {}: {e}", path.display()),
            ))
        })?
        .len();
    if file_size > MAX_LORA_SIZE {
        return Err(TuneError::Io(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "LoRA file {} is {} bytes, exceeds maximum of {} bytes",
                path.display(),
                file_size,
                MAX_LORA_SIZE
            ),
        )));
    }

    let data = std::fs::read(path).map_err(|e| {
        TuneError::Io(std::io::Error::new(
            e.kind(),
            format!("failed to read LoRA adapter from {}: {e}", path.display()),
        ))
    })?;
    let tensors = SafeTensors::deserialize(&data).map_err(|e| {
        TuneError::Serialization(format!(
            "failed to parse safetensors from {}: {e}",
            path.display()
        ))
    })?;

    let names: Vec<String> = tensors.names().into_iter().map(String::from).collect();
    let mut a_tensors: HashMap<(usize, String), (Vec<f32>, Vec<usize>)> = HashMap::new();
    let mut b_tensors: HashMap<(usize, String), (Vec<f32>, Vec<usize>)> = HashMap::new();
    let mut target_modules = std::collections::BTreeSet::new();

    for name in &names {
        let peft_key = match parse_peft_key(name) {
            Some(parsed) => parsed,
            None => continue,
        };
        let (values, shape) = read_tensor_f32(&tensors, name)?;

        let (values, shape) = if peft_key.transposed && shape.len() == 2 {
            let (rows, cols) = (shape[0], shape[1]);
            // Validate before indexing so a malformed shape yields an error
            // rather than an out-of-bounds panic.
            if rows.checked_mul(cols) != Some(values.len()) {
                return Err(TuneError::Serialization(format!(
                    "tensor '{name}' shape {shape:?} does not match its {} elements",
                    values.len()
                )));
            }
            let mut transposed = vec![0.0f32; values.len()];
            // With no elements there is nothing to move; skipping also avoids
            // iterating over a huge dimension paired with a zero dimension.
            if !values.is_empty() {
                for r in 0..rows {
                    for c in 0..cols {
                        transposed[c * rows + r] = values[r * cols + c];
                    }
                }
            }
            (transposed, vec![cols, rows])
        } else {
            (values, shape)
        };

        target_modules.insert(peft_key.module.clone());

        let (slot, label) = match peft_key.matrix {
            LoraMatrix::A => (&mut a_tensors, "lora_A"),
            LoraMatrix::B => (&mut b_tensors, "lora_B"),
        };
        // The module is only the last path component of the tensor name, so
        // two different names can collapse onto one key. Reject that instead
        // of letting hash iteration order decide which tensor survives.
        match slot.entry((peft_key.layer_idx, peft_key.module)) {
            std::collections::hash_map::Entry::Occupied(existing) => {
                let (layer_idx, module) = existing.key();
                return Err(TuneError::Serialization(format!(
                    "LoRA adapter has more than one {label} tensor for layer {layer_idx} module '{module}' (duplicate: '{name}')"
                )));
            }
            std::collections::hash_map::Entry::Vacant(vacant) => {
                vacant.insert((values, shape));
            }
        }
    }

    if a_tensors.is_empty() && b_tensors.is_empty() {
        return Err(TuneError::Serialization(
            "LoRA adapter contains no lora_A/lora_B tensors".to_string(),
        ));
    }

    let mut layers: HashMap<(usize, String), LoraLayer> = HashMap::with_capacity(a_tensors.len());
    let mut rank: Option<usize> = None;

    // Consume both maps so the tensor buffers are moved into the layers
    // instead of being cloned.
    for (key, (a_data, a_shape)) in a_tensors {
        let (b_data, b_shape) = b_tensors.remove(&key).ok_or_else(|| {
            TuneError::Serialization(format!(
                "LoRA adapter has lora_A for layer {} module '{}' but no matching lora_B",
                key.0, key.1
            ))
        })?;

        if a_shape.len() != 2 || b_shape.len() != 2 {
            return Err(TuneError::Serialization(format!(
                "LoRA matrices for layer {} module '{}' must be 2D, got A={:?} B={:?}",
                key.0, key.1, a_shape, b_shape
            )));
        }

        // A is [rank, d_in]; B is [d_out, rank].
        let a_rank = a_shape[0];
        let d_in = a_shape[1];
        let d_out = b_shape[0];
        let b_rank = b_shape[1];

        if a_rank != b_rank {
            return Err(TuneError::Serialization(format!(
                "LoRA rank mismatch for layer {} module '{}': A rank={}, B rank={}",
                key.0, key.1, a_rank, b_rank
            )));
        }
        match rank {
            None => rank = Some(a_rank),
            Some(r) if r != a_rank => {
                return Err(TuneError::Serialization(format!(
                    "inconsistent LoRA ranks: first seen rank={}, layer {} module '{}' has rank={}",
                    r, key.0, key.1, a_rank
                )));
            }
            _ => {}
        }

        let expected_a = a_rank.checked_mul(d_in).ok_or_else(|| {
            TuneError::Serialization(format!(
                "LoRA A shape overflow for layer {} module '{}': {} * {}",
                key.0, key.1, a_rank, d_in
            ))
        })?;
        if a_data.len() != expected_a {
            return Err(TuneError::Serialization(format!(
                "LoRA A data size mismatch for layer {} module '{}': expected {}, got {}",
                key.0,
                key.1,
                expected_a,
                a_data.len()
            )));
        }

        let expected_b = d_out.checked_mul(b_rank).ok_or_else(|| {
            TuneError::Serialization(format!(
                "LoRA B shape overflow for layer {} module '{}': {} * {}",
                key.0, key.1, d_out, b_rank
            ))
        })?;
        if b_data.len() != expected_b {
            return Err(TuneError::Serialization(format!(
                "LoRA B data size mismatch for layer {} module '{}': expected {}, got {}",
                key.0,
                key.1,
                expected_b,
                b_data.len()
            )));
        }

        layers.insert(
            key,
            LoraLayer {
                a: a_data,
                b: b_data,
                d_in,
                d_out,
                rank: a_rank,
            },
        );
    }

    // Every matched B was removed above; anything left has no A partner.
    if let Some(key) = b_tensors.keys().next() {
        return Err(TuneError::Serialization(format!(
            "LoRA adapter has lora_B for layer {} module '{}' but no matching lora_A",
            key.0, key.1
        )));
    }

    let rank = rank.unwrap_or(0);
    Ok(LoraAdapter {
        config: LoraConfig {
            rank,
            // Alpha is not read from the file; it is set equal to the rank.
            alpha: rank as f32,
            target_modules: target_modules.into_iter().collect(),
        },
        layers,
    })
}
/// Writes `adapter` to `path` as a safetensors file using PEFT-style tensor
/// names.
///
/// Each layer produces two F32 tensors,
/// `base_model.model.model.layers.<idx>.<block>.<module>.lora_A.weight`
/// with shape `[rank, d_in]` and the matching `lora_B.weight` with shape
/// `[d_out, rank]`. `<block>` is `mlp` for `gate_proj` / `up_proj` /
/// `down_proj` and `self_attn` for every other module name. The config's
/// rank, alpha and comma-joined target modules are stored as file metadata.
///
/// # Errors
///
/// Returns `TuneError::Serialization` if a tensor view cannot be created or
/// serialization fails, and `TuneError::Io` if the file cannot be written.
pub fn save_peft_safetensors(adapter: &LoraAdapter, path: &Path) -> Result<(), TuneError> {
    /// Flattens `f32` values into their little-endian byte representation.
    fn le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    // Own the names, shapes and byte buffers first; the tensor views built
    // below borrow from this vector.
    let mut staged: Vec<(String, Vec<usize>, Vec<u8>)> = Vec::new();
    for ((layer_idx, module), layer) in &adapter.layers {
        let block = match module.as_str() {
            "gate_proj" | "up_proj" | "down_proj" => "mlp",
            "q_proj" | "k_proj" | "v_proj" | "o_proj" => "self_attn",
            _ => "self_attn",
        };
        let a_name =
            format!("base_model.model.model.layers.{layer_idx}.{block}.{module}.lora_A.weight");
        let b_name =
            format!("base_model.model.model.layers.{layer_idx}.{block}.{module}.lora_B.weight");
        staged.push((a_name, vec![layer.rank, layer.d_in], le_bytes(&layer.a)));
        staged.push((b_name, vec![layer.d_out, layer.rank], le_bytes(&layer.b)));
    }

    let tensor_views = staged
        .iter()
        .map(|(name, shape, bytes)| {
            TensorView::new(Dtype::F32, shape.clone(), bytes)
                .map(|view| (name.clone(), view))
                .map_err(|e| {
                    TuneError::Serialization(format!(
                        "failed to create tensor view for '{name}': {e}"
                    ))
                })
        })
        .collect::<Result<HashMap<String, TensorView<'_>>, TuneError>>()?;

    let metadata = Some(HashMap::from([
        ("rank".to_string(), adapter.config.rank.to_string()),
        ("alpha".to_string(), adapter.config.alpha.to_string()),
        (
            "target_modules".to_string(),
            adapter.config.target_modules.join(","),
        ),
    ]));

    let encoded = serialize(&tensor_views, &metadata)
        .map_err(|e| TuneError::Serialization(format!("failed to serialize LoRA adapter: {e}")))?;

    std::fs::write(path, encoded).map_err(|e| {
        TuneError::Io(std::io::Error::new(
            e.kind(),
            format!("failed to write LoRA adapter to {}: {e}", path.display()),
        ))
    })
}
// Unit tests for key parsing, half-precision decoding, and loading/saving
// adapters through temporary safetensors files.
#[cfg(test)]
mod tests {
use super::*;
// PEFT-style attention key with the long `base_model.model.model` prefix.
#[test]
fn test_parse_peft_key_self_attn() {
let key = "base_model.model.model.layers.5.self_attn.q_proj.lora_A.weight";
let parsed = parse_peft_key(key).unwrap();
assert_eq!(parsed.layer_idx, 5);
assert_eq!(parsed.module, "q_proj");
assert_eq!(parsed.matrix, LoraMatrix::A);
}
// PEFT-style MLP key naming the B matrix.
#[test]
fn test_parse_peft_key_mlp() {
let key = "base_model.model.model.layers.12.mlp.gate_proj.lora_B.weight";
let parsed = parse_peft_key(key).unwrap();
assert_eq!(parsed.layer_idx, 12);
assert_eq!(parsed.module, "gate_proj");
assert_eq!(parsed.matrix, LoraMatrix::B);
}
// A shorter `model.layers.` prefix must parse the same way.
#[test]
fn test_parse_peft_key_simple_prefix() {
let key = "model.layers.0.self_attn.v_proj.lora_A.weight";
let parsed = parse_peft_key(key).unwrap();
assert_eq!(parsed.layer_idx, 0);
assert_eq!(parsed.module, "v_proj");
assert_eq!(parsed.matrix, LoraMatrix::A);
}
// The trailing `.weight` is optional.
#[test]
fn test_parse_peft_key_no_weight_suffix() {
let key = "base_model.model.model.layers.3.mlp.up_proj.lora_B";
let parsed = parse_peft_key(key).unwrap();
assert_eq!(parsed.layer_idx, 3);
assert_eq!(parsed.module, "up_proj");
assert_eq!(parsed.matrix, LoraMatrix::B);
}
// Names without a LoRA suffix (including the empty string) yield `None`.
#[test]
fn test_parse_peft_key_rejects_non_lora() {
assert!(parse_peft_key("model.layers.0.self_attn.q_proj.weight").is_none());
assert!(parse_peft_key("some_random_tensor").is_none());
assert!(parse_peft_key("").is_none());
}
// Lowercase `lora_a` is accepted and marks the key as transposed.
#[test]
fn test_parse_mlx_key_lowercase() {
let key = "model.layers.3.self_attn.q_proj.lora_a";
let parsed = parse_peft_key(key).unwrap();
assert_eq!(parsed.layer_idx, 3);
assert_eq!(parsed.module, "q_proj");
assert_eq!(parsed.matrix, LoraMatrix::A);
assert!(parsed.transposed);
}
// Lowercase `lora_b` on an MLP module is likewise marked as transposed.
#[test]
fn test_parse_mlx_key_mlp() {
let key = "model.layers.7.mlp.down_proj.lora_b";
let parsed = parse_peft_key(key).unwrap();
assert_eq!(parsed.layer_idx, 7);
assert_eq!(parsed.module, "down_proj");
assert_eq!(parsed.matrix, LoraMatrix::B);
assert!(parsed.transposed);
}
// Uppercase `lora_A` keys must not be marked as transposed.
#[test]
fn test_peft_key_not_transposed() {
let key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight";
let parsed = parse_peft_key(key).unwrap();
assert!(!parsed.transposed);
}
// Every module name is parsed under the block (`self_attn` or `mlp`) it is
// zipped with below.
#[test]
fn test_parse_peft_key_all_target_modules() {
let modules = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
];
for (module, block) in modules.iter().zip(
[
"self_attn",
"self_attn",
"self_attn",
"self_attn",
"mlp",
"mlp",
"mlp",
]
.iter(),
) {
let key = format!("base_model.model.model.layers.7.{block}.{module}.lora_A.weight");
let parsed = parse_peft_key(&key).unwrap();
assert_eq!(parsed.layer_idx, 7);
assert_eq!(parsed.module, *module);
assert_eq!(parsed.matrix, LoraMatrix::A);
}
}
// Spot-checks f16 decoding for 1.0, -1.0, 0.0 and 0.5.
#[test]
fn test_f16_to_f32_roundtrip() {
assert!((f16_to_f32(0x3C00) - 1.0).abs() < 1e-6);
assert!((f16_to_f32(0xBC00) - (-1.0)).abs() < 1e-6);
assert_eq!(f16_to_f32(0x0000), 0.0);
assert!((f16_to_f32(0x3800) - 0.5).abs() < 1e-6);
}
// Spot-checks bf16 decoding for 1.0 and 0.0.
#[test]
fn test_bf16_to_f32() {
assert_eq!(bf16_to_f32(0x3F80), 1.0);
assert_eq!(bf16_to_f32(0x0000), 0.0);
}
// Fixture: writes a PEFT-style F32 file with two adapted modules
// (layer 0 `q_proj`, layer 2 `gate_proj`). A is `[rank, d_in]`, B is
// `[d_out, rank]`; the second module uses the negated ramp values.
fn write_test_peft_safetensors(path: &std::path::Path, rank: usize, d_in: usize, d_out: usize) {
use safetensors::Dtype;
use safetensors::tensor::{TensorView, serialize};
use std::collections::HashMap;
let a_data: Vec<f32> = (0..rank * d_in).map(|i| (i as f32) * 0.01).collect();
let b_data: Vec<f32> = (0..d_out * rank).map(|i| (i as f32) * 0.1).collect();
let a2_data: Vec<f32> = (0..rank * d_in).map(|i| (i as f32) * -0.01).collect();
let b2_data: Vec<f32> = (0..d_out * rank).map(|i| (i as f32) * -0.1).collect();
let a_bytes: Vec<u8> = a_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let b_bytes: Vec<u8> = b_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let a2_bytes: Vec<u8> = a2_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let b2_bytes: Vec<u8> = b2_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let mut tensors = HashMap::new();
tensors.insert(
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight".to_string(),
TensorView::new(Dtype::F32, vec![rank, d_in], &a_bytes).unwrap(),
);
tensors.insert(
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight".to_string(),
TensorView::new(Dtype::F32, vec![d_out, rank], &b_bytes).unwrap(),
);
tensors.insert(
"base_model.model.model.layers.2.mlp.gate_proj.lora_A.weight".to_string(),
TensorView::new(Dtype::F32, vec![rank, d_in], &a2_bytes).unwrap(),
);
tensors.insert(
"base_model.model.model.layers.2.mlp.gate_proj.lora_B.weight".to_string(),
TensorView::new(Dtype::F32, vec![d_out, rank], &b2_bytes).unwrap(),
);
let bytes = serialize(&tensors, &None).unwrap();
std::fs::write(path, bytes).unwrap();
}
// Loads the fixture file and checks the inferred config, the layer keys,
// the dimensions, and a few A values for both modules.
#[test]
fn test_load_peft_safetensors_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("adapter.safetensors");
let rank = 4;
let d_in = 8;
let d_out = 8;
write_test_peft_safetensors(&path, rank, d_in, d_out);
let adapter = load_peft_safetensors(&path).unwrap();
assert_eq!(adapter.config.rank, rank);
assert_eq!(adapter.config.target_modules.len(), 2);
assert!(
adapter
.config
.target_modules
.contains(&"q_proj".to_string())
);
assert!(
adapter
.config
.target_modules
.contains(&"gate_proj".to_string())
);
assert_eq!(adapter.layers.len(), 2);
assert!(adapter.layers.contains_key(&(0, "q_proj".to_string())));
assert!(adapter.layers.contains_key(&(2, "gate_proj".to_string())));
let q_lora = &adapter.layers[&(0, "q_proj".to_string())];
assert_eq!(q_lora.rank, rank);
assert_eq!(q_lora.d_in, d_in);
assert_eq!(q_lora.d_out, d_out);
assert_eq!(q_lora.a.len(), rank * d_in);
assert_eq!(q_lora.b.len(), d_out * rank);
assert!((q_lora.a[0] - 0.0).abs() < 1e-6);
assert!((q_lora.a[1] - 0.01).abs() < 1e-6);
assert!((q_lora.a[2] - 0.02).abs() < 1e-6);
let g_lora = &adapter.layers[&(2, "gate_proj".to_string())];
assert_eq!(g_lora.rank, rank);
assert!((g_lora.a[1] - (-0.01)).abs() < 1e-6);
}
// Lowercase-keyed tensors are stored here as A = [4, 2] and B = [2, 3];
// after loading, the layer must report rank 2, d_in 4, d_out 3, and A must
// hold the transposed (row-major [2, 4]) element order.
#[test]
fn test_load_mlx_safetensors_transposed() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("adapters.safetensors");
use safetensors::Dtype;
use safetensors::tensor::{TensorView, serialize};
let a_data: Vec<f32> = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
];
let b_data: Vec<f32> = vec![
0.1, 0.2, 0.3, 0.4, 0.5, 0.6,
];
let a_bytes: Vec<u8> = a_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let b_bytes: Vec<u8> = b_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let mut tensors = std::collections::HashMap::new();
tensors.insert(
"model.layers.0.self_attn.q_proj.lora_a".to_string(),
TensorView::new(Dtype::F32, vec![4, 2], &a_bytes).unwrap(),
);
tensors.insert(
"model.layers.0.self_attn.q_proj.lora_b".to_string(),
TensorView::new(Dtype::F32, vec![2, 3], &b_bytes).unwrap(),
);
let bytes = serialize(&tensors, &None).unwrap();
std::fs::write(&path, bytes).unwrap();
let adapter = load_peft_safetensors(&path).unwrap();
assert_eq!(adapter.config.rank, 2);
let lora = &adapter.layers[&(0, "q_proj".to_string())];
assert_eq!(lora.rank, 2);
assert_eq!(lora.d_in, 4);
assert_eq!(lora.d_out, 3);
// First transposed row is the first column of the stored matrix: 1, 3, 5, 7.
assert!((lora.a[0] - 1.0).abs() < 1e-6); assert!((lora.a[1] - 3.0).abs() < 1e-6);
assert!((lora.a[2] - 5.0).abs() < 1e-6);
assert!((lora.a[3] - 7.0).abs() < 1e-6);
// Second transposed row starts with the second column: 2, 4, ...
assert!((lora.a[4] - 2.0).abs() < 1e-6); assert!((lora.a[5] - 4.0).abs() < 1e-6);
}
// Loads a rank-1 adapter (A = [0, 1] as 1x2, B = [1, 0] as 2x1) and applies
// it to x = [3, 5]: output[0] must move from 10 to 15 while output[1]
// stays at 20.
#[test]
fn test_load_peft_safetensors_apply_correctness() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("adapter.safetensors");
use safetensors::Dtype;
use safetensors::tensor::{TensorView, serialize};
let a_data: Vec<f32> = vec![0.0, 1.0]; let b_data: Vec<f32> = vec![1.0, 0.0]; let a_bytes: Vec<u8> = a_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let b_bytes: Vec<u8> = b_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let mut tensors = std::collections::HashMap::new();
tensors.insert(
"model.layers.0.self_attn.v_proj.lora_A.weight".to_string(),
TensorView::new(Dtype::F32, vec![1, 2], &a_bytes).unwrap(),
);
tensors.insert(
"model.layers.0.self_attn.v_proj.lora_B.weight".to_string(),
TensorView::new(Dtype::F32, vec![2, 1], &b_bytes).unwrap(),
);
let bytes = serialize(&tensors, &None).unwrap();
std::fs::write(&path, bytes).unwrap();
let adapter = load_peft_safetensors(&path).unwrap();
assert_eq!(adapter.config.rank, 1);
let lora = &adapter.layers[&(0, "v_proj".to_string())];
let x = [3.0f32, 5.0];
let mut output = [10.0, 20.0];
super::super::apply_lora(lora, adapter.config.scale(), &x, &mut output);
assert!((output[0] - 15.0).abs() < 1e-6);
assert!((output[1] - 20.0).abs() < 1e-6);
}
// Fixture: writes a hand-built safetensors file as an 8-byte little-endian
// header length, the header text, then the raw data bytes. This allows
// headers that the `serialize` helper would not produce.
fn write_raw_safetensors(path: &std::path::Path, header: &str, data: &[u8]) {
let mut bytes = Vec::new();
bytes.extend_from_slice(&(header.len() as u64).to_le_bytes());
bytes.extend_from_slice(header.as_bytes());
bytes.extend_from_slice(data);
std::fs::write(path, bytes).unwrap();
}
// The file holds only the length prefix (claiming a 128-byte header) and
// nothing else, so parsing must fail.
#[test]
fn test_rejects_truncated_safetensors_header() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("truncated.safetensors");
std::fs::write(&path, 128u64.to_le_bytes()).unwrap();
let err = load_peft_safetensors(&path).unwrap_err();
assert!(err.to_string().contains("failed to parse safetensors"));
}
// A well-formed file whose only tensor is not a LoRA tensor is rejected.
#[test]
fn test_rejects_file_without_lora_tensors() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("empty.safetensors");
let header = r#"{"metadata.weight":{"dtype":"F32","shape":[1],"data_offsets":[0,4]}}"#;
write_raw_safetensors(&path, header, &1.0f32.to_le_bytes());
let err = load_peft_safetensors(&path).unwrap_err();
assert!(err.to_string().contains("no lora_A/lora_B tensors"));
}
// The header declares shapes containing `usize::MAX` over zero bytes of
// data. Either the loader's own "shape overflow" error or a rejection at
// parse time is accepted.
#[test]
fn test_rejects_shape_product_overflow() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("overflow.safetensors");
let max = usize::MAX;
let header = format!(
r#"{{"model.layers.0.self_attn.q_proj.lora_A.weight":{{"dtype":"F32","shape":[{max},2],"data_offsets":[0,0]}},"model.layers.0.self_attn.q_proj.lora_B.weight":{{"dtype":"F32","shape":[1,{max}],"data_offsets":[0,0]}}}}"#
);
write_raw_safetensors(&path, &header, &[]);
let err = load_peft_safetensors(&path).unwrap_err();
assert!(
err.to_string().contains("shape overflow")
|| err.to_string().contains("failed to parse safetensors")
);
}
// A NaN inside the A matrix must make loading fail with a "non-finite" error.
#[test]
fn test_rejects_non_finite_tensor_values() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("nan.safetensors");
use safetensors::Dtype;
use safetensors::tensor::{TensorView, serialize};
let a_data: Vec<f32> = vec![f32::NAN, 1.0];
let b_data: Vec<f32> = vec![1.0, 2.0];
let a_bytes: Vec<u8> = a_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let b_bytes: Vec<u8> = b_data.iter().flat_map(|f| f.to_le_bytes()).collect();
let mut tensors = std::collections::HashMap::new();
tensors.insert(
"model.layers.0.self_attn.q_proj.lora_A.weight".to_string(),
TensorView::new(Dtype::F32, vec![1, 2], &a_bytes).unwrap(),
);
tensors.insert(
"model.layers.0.self_attn.q_proj.lora_B.weight".to_string(),
TensorView::new(Dtype::F32, vec![2, 1], &b_bytes).unwrap(),
);
let bytes = serialize(&tensors, &None).unwrap();
std::fs::write(&path, bytes).unwrap();
let err = load_peft_safetensors(&path).unwrap_err();
assert!(err.to_string().contains("non-finite"));
}
// Saves a two-layer adapter, loads it back, and compares dimensions and
// every A/B element. Alpha is compared against the rank, which is also the
// value the saved config uses here.
#[test]
fn test_save_load_round_trip() {
use tempfile::NamedTempFile;
let rank = 4;
let d_in = 8;
let d_out = 16;
let config = LoraConfig {
rank,
alpha: rank as f32,
target_modules: vec!["q_proj".to_string(), "gate_proj".to_string()],
};
let a_data: Vec<f32> = (0..rank * d_in).map(|i| i as f32 * 0.01).collect();
let b_data: Vec<f32> = (0..d_out * rank).map(|i| i as f32 * 0.1).collect();
let a2_data: Vec<f32> = (0..rank * d_in).map(|i| i as f32 * -0.01).collect();
let b2_data: Vec<f32> = (0..d_out * rank).map(|i| i as f32 * -0.1).collect();
let mut layers = HashMap::new();
layers.insert(
(0usize, "q_proj".to_string()),
LoraLayer {
a: a_data.clone(),
b: b_data.clone(),
d_in,
d_out,
rank,
},
);
layers.insert(
(2usize, "gate_proj".to_string()),
LoraLayer {
a: a2_data.clone(),
b: b2_data.clone(),
d_in,
d_out,
rank,
},
);
let adapter = LoraAdapter::new(config, layers);
let temp = NamedTempFile::new().unwrap();
save_peft_safetensors(&adapter, temp.path()).unwrap();
let loaded = load_peft_safetensors(temp.path()).unwrap();
assert_eq!(loaded.layers.len(), adapter.layers.len());
assert_eq!(loaded.config.rank, adapter.config.rank);
assert_eq!(loaded.config.alpha, adapter.config.rank as f32);
for (key, orig) in &adapter.layers {
let got = loaded
.layers
.get(key)
.expect("layer key missing after round-trip");
assert_eq!(got.rank, orig.rank);
assert_eq!(got.d_in, orig.d_in);
assert_eq!(got.d_out, orig.d_out);
assert_eq!(got.a.len(), orig.a.len());
assert_eq!(got.b.len(), orig.b.len());
for (g, w) in got.a.iter().zip(&orig.a) {
assert!((g - w).abs() < f32::EPSILON, "A mismatch: {g} vs {w}");
}
for (g, w) in got.b.iter().zip(&orig.b) {
assert!((g - w).abs() < f32::EPSILON, "B mismatch: {g} vs {w}");
}
}
}
}