use std::collections::HashMap;
use std::path::Path;
use crate::weight_loader::ModelWeights;
/// A set of LoRA layers loaded from one adapter file.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
/// LoRA layers keyed by the name of the base weight they modify.
///
/// `load_lora_from_bytes` fills this with keys of the form
/// `model.<path>.weight`, and `merge_lora` looks each key up in the base
/// model's tensor map.
pub adapters: HashMap<String, LoraLayer>,
}
/// One low-rank adapter pair (`A`, `B`) for a single weight matrix.
///
/// The weight update it represents is `(alpha / rank) * B * A`.
#[derive(Debug, Clone)]
pub struct LoraLayer {
    /// The `A` matrix, row-major, `rank` rows by `in_features` columns.
    pub lora_a: Vec<f32>,
    /// The `B` matrix, row-major, `out_features` rows by `rank` columns.
    pub lora_b: Vec<f32>,
    /// Inner dimension shared by `A` and `B`.
    pub rank: usize,
    /// Scaling numerator; the product `B * A` is multiplied by `alpha / rank`.
    pub alpha: f32,
    /// Number of columns of `A` (and of the resulting delta).
    pub in_features: usize,
    /// Number of rows of `B` (and of the resulting delta).
    pub out_features: usize,
}

impl LoraLayer {
    /// Computes the dense weight delta `(alpha / rank) * B * A`.
    ///
    /// The result is row-major with `out_features` rows and `in_features`
    /// columns.
    ///
    /// A layer with `rank == 0` contributes nothing, so the zero matrix is
    /// returned instead of dividing `alpha` by zero (which would turn every
    /// element into NaN).
    ///
    /// # Panics
    ///
    /// Panics if `lora_a` holds fewer than `rank * in_features` values or
    /// `lora_b` holds fewer than `out_features * rank` values.
    pub fn compute_delta(&self) -> Vec<f32> {
        let mut delta = vec![0.0f32; self.out_features * self.in_features];
        // Nothing to accumulate; also keeps `alpha / rank` away from a zero
        // divisor and `chunks_exact_mut` away from a zero chunk size.
        if self.rank == 0 || delta.is_empty() {
            return delta;
        }

        let scale = self.alpha / self.rank as f32;
        for (i, out_row) in delta.chunks_exact_mut(self.in_features).enumerate() {
            let b_row = &self.lora_b[i * self.rank..(i + 1) * self.rank];
            for (k, &b_ik) in b_row.iter().enumerate() {
                let a_row = &self.lora_a[k * self.in_features..(k + 1) * self.in_features];
                for (acc, &a_kj) in out_row.iter_mut().zip(a_row) {
                    *acc += b_ik * a_kj;
                }
            }
        }
        // Scale after accumulation so rounding matches a plain `B * A`
        // product followed by one multiplication.
        for value in &mut delta {
            *value *= scale;
        }
        delta
    }
}
/// Errors produced while loading a LoRA adapter.
#[derive(Debug, thiserror::Error)]
pub enum LoraError {
/// Reading the adapter file failed; converted automatically from `std::io::Error`.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// The SafeTensors data or the tensors it describes are malformed; the
/// string carries the human-readable reason.
#[error("SafeTensors parse error: {0}")]
Parse(String),
/// A `lora_B` tensor was present for the named weight but no `lora_A`.
#[error("missing lora_A for layer: {0}")]
MissingLoraA(String),
/// A `lora_A` tensor was present for the named weight but no `lora_B`.
#[error("missing lora_B for layer: {0}")]
MissingLoraB(String),
}
/// Tensors decoded from a SafeTensors file: tensor name mapped to its values
/// (converted to `f32`) and its declared shape.
type ParsedTensors = HashMap<String, (Vec<f32>, Vec<usize>)>;
/// Parses an in-memory SafeTensors file into a map of decoded tensors.
///
/// The layout read here is: an 8-byte little-endian header length, a JSON
/// header of that many bytes, then the tensor data region. Each header entry
/// other than `__metadata__` must provide `dtype`, `shape` and a two-element
/// `data_offsets` array that is relative to the start of the data region.
///
/// All lengths and offsets come from the file, so they are validated before
/// any slicing: a header length or offset that is too large, or a range whose
/// start lies after its end, yields an error rather than an arithmetic
/// overflow or an out-of-bounds panic.
///
/// # Errors
///
/// Returns [`LoraError::Parse`] when the file is truncated, the header is not
/// valid JSON, a required field is missing or malformed, a data range is
/// invalid, or the dtype cannot be converted by `bytes_to_f32`.
fn parse_safetensors(data: &[u8]) -> Result<ParsedTensors, LoraError> {
    const LEN_PREFIX: usize = 8;

    if data.len() < LEN_PREFIX {
        return Err(LoraError::Parse(
            "file too small for SafeTensors header".into(),
        ));
    }
    let mut len_bytes = [0u8; LEN_PREFIX];
    len_bytes.copy_from_slice(&data[..LEN_PREFIX]);
    let header_len = u64::from_le_bytes(len_bytes);

    // Compare in u64 space against the bytes actually available, so a huge
    // declared length can neither truncate in a cast nor overflow `8 + len`.
    if header_len > (data.len() - LEN_PREFIX) as u64 {
        return Err(LoraError::Parse(format!(
            "header length {header_len} exceeds file size {}",
            data.len()
        )));
    }
    // Lossless: `header_len` was just bounded by a `usize` value.
    let header_end = LEN_PREFIX + header_len as usize;
    let header_json = &data[LEN_PREFIX..header_end];
    let tensor_region = &data[header_end..];

    let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(header_json)
        .map_err(|e| LoraError::Parse(format!("JSON parse error: {e}")))?;

    let mut result = HashMap::with_capacity(raw.len());
    for (key, value) in raw {
        if key == "__metadata__" {
            continue;
        }
        let dtype = value
            .get("dtype")
            .and_then(|v| v.as_str())
            .ok_or_else(|| LoraError::Parse(format!("missing dtype for tensor {key}")))?;
        let shape: Vec<usize> = value
            .get("shape")
            .and_then(|v| v.as_array())
            .ok_or_else(|| LoraError::Parse(format!("missing shape for tensor {key}")))?
            .iter()
            .map(|v| {
                v.as_u64()
                    .map(|n| n as usize)
                    .ok_or_else(|| LoraError::Parse(format!("invalid shape element in {key}")))
            })
            .collect::<Result<Vec<_>, _>>()?;
        let offsets = value
            .get("data_offsets")
            .and_then(|v| v.as_array())
            .ok_or_else(|| LoraError::Parse(format!("missing data_offsets for tensor {key}")))?;
        let (start, end) = match offsets.as_slice() {
            [start, end] => (
                start.as_u64().ok_or_else(|| {
                    LoraError::Parse(format!("invalid data_offsets[0] for {key}"))
                })?,
                end.as_u64().ok_or_else(|| {
                    LoraError::Parse(format!("invalid data_offsets[1] for {key}"))
                })?,
            ),
            _ => {
                return Err(LoraError::Parse(format!(
                    "data_offsets must have 2 elements for tensor {key}"
                )))
            }
        };
        if start > end {
            return Err(LoraError::Parse(format!(
                "tensor {key} has inverted data range [{start},{end})"
            )));
        }
        if end > tensor_region.len() as u64 {
            return Err(LoraError::Parse(format!(
                "tensor {key} data range [{start},{end}) exceeds file size"
            )));
        }
        // Lossless casts: both offsets are bounded by `tensor_region.len()`.
        let raw_bytes = &tensor_region[start as usize..end as usize];
        let floats = bytes_to_f32(raw_bytes, dtype)
            .map_err(|e| LoraError::Parse(format!("dtype conversion for {key}: {e}")))?;
        result.insert(key, (floats, shape));
    }
    Ok(result)
}
fn bytes_to_f32(data: &[u8], dtype: &str) -> Result<Vec<f32>, String> {
match dtype {
"F32" => {
if !data.len().is_multiple_of(4) {
return Err(format!("F32 data length {} not divisible by 4", data.len()));
}
Ok(data
.chunks_exact(4)
.map(|b| f32::from_le_bytes(b.try_into().unwrap()))
.collect())
}
"F16" => {
if !data.len().is_multiple_of(2) {
return Err(format!("F16 data length {} not divisible by 2", data.len()));
}
Ok(data
.chunks_exact(2)
.map(|b| {
let bits = u16::from_le_bytes(b.try_into().unwrap());
f16_to_f32(bits)
})
.collect())
}
"BF16" => {
if !data.len().is_multiple_of(2) {
return Err(format!(
"BF16 data length {} not divisible by 2",
data.len()
));
}
Ok(data
.chunks_exact(2)
.map(|b| {
let bits = u16::from_le_bytes(b.try_into().unwrap());
bf16_to_f32(bits)
})
.collect())
}
other => Err(format!("unsupported dtype for LoRA: {other}")),
}
}
/// Widens an IEEE 754 half-precision bit pattern to `f32`.
///
/// Zeros keep their sign, subnormals are renormalised, and infinities/NaNs
/// keep their mantissa payload (shifted into the wider field).
#[inline]
fn f16_to_f32(bits: u16) -> f32 {
    let sign_bit = u32::from(bits >> 15) << 31;
    let exponent = u32::from((bits >> 10) & 0x1f);
    let mantissa = u32::from(bits & 0x3ff);

    let magnitude: u32 = match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the leading 1 up to the implicit-bit position
        // (bit 10) and lower the exponent by the same amount.
        (0, _) => {
            let shift = mantissa.leading_zeros() - 21;
            let biased_exp = (127 - 14) - shift;
            let fraction = (mantissa << shift) & 0x3ff;
            (biased_exp << 23) | (fraction << 13)
        }
        // Infinity or NaN: all-ones exponent in both formats.
        (0x1f, _) => (0xff << 23) | (mantissa << 13),
        // Normal number: rebias the exponent from 15 to 127.
        _ => ((exponent + 127 - 15) << 23) | (mantissa << 13),
    };
    f32::from_bits(sign_bit | magnitude)
}
/// Widens a bfloat16 bit pattern to `f32`.
///
/// The 16 input bits become the upper half of the `f32` bit pattern and the
/// lower half is zero-filled.
#[inline]
fn bf16_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits) << 16;
    f32::from_bits(widened)
}
/// Which part of a LoRA layer a tensor key refers to, as classified by
/// `parse_lora_key`.
#[derive(Debug, PartialEq)]
enum LoraKeyKind {
/// Key ending in `.lora_A.weight`.
A,
/// Key ending in `.lora_B.weight`.
B,
/// Key ending in `.lora_alpha` or `.alpha`.
Alpha,
}
/// Classifies a tensor key from an adapter file and maps it to the name of
/// the base weight it applies to.
///
/// An optional leading `base_model.model.` is dropped, the recognised suffix
/// (`.lora_A.weight`, `.lora_B.weight`, `.lora_alpha` or `.alpha`, tried in
/// that order) is removed, and the remainder is wrapped as
/// `model.<remainder>.weight`.
///
/// Returns `None` when the key has none of the recognised suffixes.
fn parse_lora_key(key: &str) -> Option<(String, LoraKeyKind)> {
    let stem = key.strip_prefix("base_model.model.").unwrap_or(key);

    let (base, kind) = stem
        .strip_suffix(".lora_A.weight")
        .map(|rest| (rest, LoraKeyKind::A))
        .or_else(|| {
            stem.strip_suffix(".lora_B.weight")
                .map(|rest| (rest, LoraKeyKind::B))
        })
        .or_else(|| {
            stem.strip_suffix(".lora_alpha")
                .map(|rest| (rest, LoraKeyKind::Alpha))
        })
        .or_else(|| {
            stem.strip_suffix(".alpha")
                .map(|rest| (rest, LoraKeyKind::Alpha))
        })?;

    Some((format!("model.{base}.weight"), kind))
}
pub fn load_lora(path: impl AsRef<Path>) -> Result<LoraAdapter, LoraError> {
let data = std::fs::read(path.as_ref())?;
load_lora_from_bytes(&data)
}
/// Builds a [`LoraAdapter`] from the bytes of a SafeTensors adapter file.
///
/// Tensors are grouped by the base weight name produced by `parse_lora_key`;
/// keys it does not recognise are ignored. Every layer needs both a `lora_A`
/// (`rank` x `in_features`) and a `lora_B` (`out_features` x `rank`) tensor.
/// When no alpha tensor is present for a layer, alpha defaults to the rank,
/// which makes the scale `alpha / rank` equal to 1.
///
/// Each layer is validated so that `LoraLayer::compute_delta` cannot index
/// out of bounds or divide by zero later: the rank must be non-zero and both
/// tensors must hold exactly as many values as their shapes declare.
///
/// # Errors
///
/// * [`LoraError::Parse`] — the file is malformed, a tensor is not 2-D, the
///   ranks of `A` and `B` disagree, the rank is zero, or a tensor's element
///   count does not match its shape.
/// * [`LoraError::MissingLoraB`] / [`LoraError::MissingLoraA`] — a layer has
///   only one of the two matrices.
pub fn load_lora_from_bytes(data: &[u8]) -> Result<LoraAdapter, LoraError> {
    let tensors = parse_safetensors(data)?;

    let mut a_map: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
    let mut b_map: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
    let mut alpha_map: HashMap<String, f32> = HashMap::new();
    for (key, (data_vec, shape)) in tensors {
        match parse_lora_key(&key) {
            Some((base_name, LoraKeyKind::A)) => {
                a_map.insert(base_name, (data_vec, shape));
            }
            Some((base_name, LoraKeyKind::B)) => {
                b_map.insert(base_name, (data_vec, shape));
            }
            Some((base_name, LoraKeyKind::Alpha)) => {
                // Alpha is stored as a one-element tensor; an empty one is
                // treated the same as an absent one.
                if let Some(&v) = data_vec.first() {
                    alpha_map.insert(base_name, v);
                }
            }
            // Not a LoRA tensor; ignore.
            None => {}
        }
    }

    let mut adapters = HashMap::with_capacity(a_map.len());
    // Consume both maps so the tensor buffers are moved into the layers
    // instead of being cloned.
    for (base_name, (a_data, a_shape)) in a_map {
        let Some((b_data, b_shape)) = b_map.remove(&base_name) else {
            return Err(LoraError::MissingLoraB(base_name));
        };
        let &[rank, in_features] = a_shape.as_slice() else {
            return Err(LoraError::Parse(format!(
                "lora_A for {base_name} must be 2-D, got {} dims",
                a_shape.len()
            )));
        };
        let &[out_features, b_inner] = b_shape.as_slice() else {
            return Err(LoraError::Parse(format!(
                "lora_B for {base_name} must be 2-D, got {} dims",
                b_shape.len()
            )));
        };
        if b_inner != rank {
            return Err(LoraError::Parse(format!(
                "rank mismatch for {base_name}: lora_A rank={rank}, lora_B inner dim={b_inner}"
            )));
        }
        if rank == 0 {
            return Err(LoraError::Parse(format!(
                "rank for {base_name} must be non-zero"
            )));
        }
        // `checked_mul` so an absurd declared shape cannot wrap around and
        // accidentally match the real element count.
        if rank.checked_mul(in_features) != Some(a_data.len()) {
            return Err(LoraError::Parse(format!(
                "lora_A for {base_name} has {} values, expected {rank}x{in_features}",
                a_data.len()
            )));
        }
        if out_features.checked_mul(rank) != Some(b_data.len()) {
            return Err(LoraError::Parse(format!(
                "lora_B for {base_name} has {} values, expected {out_features}x{rank}",
                b_data.len()
            )));
        }

        let alpha = alpha_map.get(&base_name).copied().unwrap_or(rank as f32);
        adapters.insert(
            base_name,
            LoraLayer {
                lora_a: a_data,
                lora_b: b_data,
                rank,
                alpha,
                in_features,
                out_features,
            },
        );
    }

    // Every matched B was removed above, so anything left has no A.
    if let Some(orphan) = b_map.into_keys().next() {
        return Err(LoraError::MissingLoraA(orphan));
    }
    Ok(LoraAdapter { adapters })
}
/// Adds every LoRA layer's delta to the matching base weight, in place.
///
/// Layers whose name is not present in `base.tensors` are skipped. When a
/// base weight and its delta differ in length, only the leading elements
/// common to both are updated; the rest of the base weight is left as is.
pub fn merge_lora(base: &mut ModelWeights, lora: &LoraAdapter) {
    for (weight_name, layer) in lora.adapters.iter() {
        let Some(target) = base.tensors.get_mut(weight_name) else {
            continue;
        };
        let delta = layer.compute_delta();
        // `zip` stops at the shorter side, which gives the common prefix.
        target
            .iter_mut()
            .zip(delta.iter())
            .for_each(|(weight, change)| *weight += change);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base weight name used by the merge and round-trip tests.
    const Q_PROJ: &str = "model.layers.0.self_attn.q_proj.weight";

    /// Encodes `values` as consecutive little-endian `f32` bytes.
    fn f32_le_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    /// Assembles a SafeTensors buffer: length prefix, JSON header, tensor data.
    fn build_safetensors(header_json: &str, tensor_data: &[u8]) -> Vec<u8> {
        let header = header_json.as_bytes();
        let mut buf = Vec::with_capacity(8 + header.len() + tensor_data.len());
        buf.extend_from_slice(&(header.len() as u64).to_le_bytes());
        buf.extend_from_slice(header);
        buf.extend_from_slice(tensor_data);
        buf
    }

    /// Shorthand for a rank-2, 2x2 layer with the given matrices and alpha.
    fn square_layer(lora_a: Vec<f32>, lora_b: Vec<f32>, alpha: f32) -> LoraLayer {
        LoraLayer {
            lora_a,
            lora_b,
            rank: 2,
            alpha,
            in_features: 2,
            out_features: 2,
        }
    }

    // ---- compute_delta ----

    #[test]
    fn compute_delta_identity() {
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let layer = square_layer(identity.clone(), identity.clone(), 2.0);
        assert_eq!(layer.compute_delta(), identity);
    }

    #[test]
    fn compute_delta_known_values() {
        let layer = LoraLayer {
            lora_a: vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
            lora_b: vec![1.0, 2.0, 3.0, 4.0],
            rank: 2,
            alpha: 4.0,
            in_features: 3,
            out_features: 2,
        };
        assert_eq!(
            layer.compute_delta(),
            vec![2.0, 4.0, 2.0, 6.0, 8.0, 6.0]
        );
    }

    #[test]
    fn compute_delta_zero_b() {
        let layer = square_layer(vec![1.0, 2.0, 3.0, 4.0], vec![0.0; 4], 1.0);
        assert_eq!(layer.compute_delta(), vec![0.0, 0.0, 0.0, 0.0]);
    }

    // ---- merge_lora ----

    #[test]
    fn merge_lora_applies_delta() {
        let mut base = ModelWeights {
            tensors: HashMap::from([(Q_PROJ.to_string(), vec![1.0f32, 0.0, 0.0, 1.0])]),
        };
        let lora = LoraAdapter {
            adapters: HashMap::from([(
                Q_PROJ.to_string(),
                square_layer(vec![1.0, 0.0, 0.0, 1.0], vec![0.5, 0.0, 0.0, 0.5], 2.0),
            )]),
        };

        merge_lora(&mut base, &lora);

        assert_eq!(base.tensors[Q_PROJ], vec![1.5f32, 0.0, 0.0, 1.5]);
    }

    #[test]
    fn merge_lora_skips_missing_base_weight() {
        let mut base = ModelWeights {
            tensors: HashMap::new(),
        };
        let lora = LoraAdapter {
            adapters: HashMap::from([(
                Q_PROJ.to_string(),
                LoraLayer {
                    lora_a: vec![1.0],
                    lora_b: vec![1.0],
                    rank: 1,
                    alpha: 1.0,
                    in_features: 1,
                    out_features: 1,
                },
            )]),
        };

        merge_lora(&mut base, &lora);

        assert!(base.tensors.is_empty());
    }

    // ---- parse_safetensors ----

    #[test]
    fn parse_safetensors_f32() {
        let raw = f32_le_bytes(&[1.0, 2.0, 3.0, 4.0]);
        let header = r#"{"w": {"dtype":"F32","shape":[2,2],"data_offsets":[0,16]}}"#;

        let tensors = parse_safetensors(&build_safetensors(header, &raw)).unwrap();

        let (values, shape) = &tensors["w"];
        assert_eq!(*values, vec![1.0f32, 2.0, 3.0, 4.0]);
        assert_eq!(*shape, vec![2usize, 2]);
    }

    #[test]
    fn parse_safetensors_bf16() {
        // 0x3F80 is 1.0 in bfloat16; store it twice, little-endian.
        let one_bf16: u16 = 0x3F80_u16;
        let raw: Vec<u8> = [one_bf16; 2]
            .iter()
            .flat_map(|half| half.to_le_bytes())
            .collect();
        let header = r#"{"w": {"dtype":"BF16","shape":[2],"data_offsets":[0,4]}}"#;

        let tensors = parse_safetensors(&build_safetensors(header, &raw)).unwrap();

        assert_eq!(tensors["w"].0, vec![1.0f32, 1.0f32]);
    }

    #[test]
    fn parse_safetensors_metadata_skipped() {
        let header = r#"{"__metadata__":{"version":"1"},"w":{"dtype":"F32","shape":[1],"data_offsets":[0,4]}}"#;
        let raw = f32_le_bytes(&[1.0]);

        let tensors = parse_safetensors(&build_safetensors(header, &raw)).unwrap();

        assert!(!tensors.contains_key("__metadata__"));
        assert!(tensors.contains_key("w"));
    }

    #[test]
    fn parse_safetensors_too_small() {
        let outcome = parse_safetensors(&[0u8; 4]);
        assert!(matches!(outcome, Err(LoraError::Parse(_))));
    }

    // ---- parse_lora_key ----

    #[test]
    fn parse_lora_key_peft_convention() {
        let (base, kind) =
            parse_lora_key("base_model.model.layers.0.self_attn.q_proj.lora_A.weight").unwrap();
        assert_eq!(base, "model.layers.0.self_attn.q_proj.weight");
        assert_eq!(kind, LoraKeyKind::A);
    }

    #[test]
    fn parse_lora_key_b_matrix() {
        let (base, kind) =
            parse_lora_key("base_model.model.layers.0.self_attn.q_proj.lora_B.weight").unwrap();
        assert_eq!(base, "model.layers.0.self_attn.q_proj.weight");
        assert_eq!(kind, LoraKeyKind::B);
    }

    #[test]
    fn parse_lora_key_alpha_tensor() {
        let (base, kind) =
            parse_lora_key("base_model.model.layers.0.self_attn.q_proj.lora_alpha").unwrap();
        assert_eq!(base, "model.layers.0.self_attn.q_proj.weight");
        assert_eq!(kind, LoraKeyKind::Alpha);
    }

    #[test]
    fn parse_lora_key_unknown_returns_none() {
        assert!(parse_lora_key("model.embed_tokens.weight").is_none());
        assert!(parse_lora_key("something.else").is_none());
    }

    // ---- load_lora_from_bytes ----

    #[test]
    fn load_lora_from_bytes_roundtrip() {
        let a_data: Vec<f32> = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let b_data: Vec<f32> = vec![1.0, 0.0, 0.0, 1.0];
        let alpha_data: Vec<f32> = vec![4.0];

        // Lay the three tensors out back to back and record their offsets.
        let mut tensor_data = f32_le_bytes(&a_data);
        let a_len = tensor_data.len();
        let b_start = a_len;
        tensor_data.extend(f32_le_bytes(&b_data));
        let alpha_start = tensor_data.len();
        tensor_data.extend(f32_le_bytes(&alpha_data));
        let total = tensor_data.len();

        let header = format!(
            r#"{{
"base_model.model.layers.0.self_attn.q_proj.lora_A.weight": {{"dtype":"F32","shape":[2,3],"data_offsets":[0,{a_len}]}},
"base_model.model.layers.0.self_attn.q_proj.lora_B.weight": {{"dtype":"F32","shape":[2,2],"data_offsets":[{b_start},{alpha_start}]}},
"base_model.model.layers.0.self_attn.q_proj.lora_alpha": {{"dtype":"F32","shape":[1],"data_offsets":[{alpha_start},{total}]}}
}}"#
        );
        let buf = build_safetensors(&header, &tensor_data);

        let adapter = load_lora_from_bytes(&buf).unwrap();

        assert_eq!(adapter.adapters.len(), 1);
        let layer = &adapter.adapters[Q_PROJ];
        assert_eq!(layer.rank, 2);
        assert_eq!(layer.in_features, 3);
        assert_eq!(layer.out_features, 2);
        assert!((layer.alpha - 4.0).abs() < 1e-6);
        assert_eq!(layer.lora_a, a_data);
        assert_eq!(layer.lora_b, b_data);
    }

    // ---- half-precision conversions ----

    #[test]
    fn f16_to_f32_one() {
        assert_eq!(f16_to_f32(0x3C00), 1.0f32);
    }

    #[test]
    fn f16_to_f32_zero() {
        assert_eq!(f16_to_f32(0x0000), 0.0f32);
    }

    #[test]
    fn bf16_to_f32_one() {
        assert_eq!(bf16_to_f32(0x3F80), 1.0f32);
    }
}