use std::collections::HashMap;
use std::path::Path;
use candle_core::{DType, Device, Tensor};
use safetensors::SafeTensors;
use serde::Deserialize;
/// Subset of a PEFT `adapter_config.json` understood by this loader.
///
/// Keys not declared here are silently ignored during deserialization (the struct
/// does not opt into `deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
pub struct LoraConfig {
    /// LoRA rank `r`. `LoraAdapter::load` rejects `r == 0`.
    pub r: usize,
    /// Scaling numerator: the merged update is multiplied by `lora_alpha / r`.
    pub lora_alpha: f64,
    /// Module names the adapter was trained on, when recorded in the config.
    /// NOTE(review): not consulted by the merge code in this file.
    #[serde(default)]
    pub target_modules: Option<Vec<String>>,
    /// Base checkpoint the adapter was trained against, when recorded.
    /// NOTE(review): not consulted by the merge code in this file.
    #[serde(default)]
    pub base_model_name_or_path: Option<String>,
    /// When `true`, the `B·A` update is transposed before it is added to the base
    /// weight. Defaults to `false` when absent from the JSON.
    #[serde(default)]
    pub fan_in_fan_out: bool,
}
/// The two low-rank factors trained for one adapted module.
///
/// The merge computes the update as `lora_b.matmul(lora_a)`, so `lora_a` is expected
/// to be `(r, in_features)` and `lora_b` to be `(out_features, r)`.
#[derive(Debug, Clone)]
pub struct LoraModule {
    /// The `lora_A` factor (`(r, in_features)`).
    pub lora_a: Tensor,
    /// The `lora_B` factor (`(out_features, r)`).
    pub lora_b: Tensor,
}

/// A PEFT LoRA adapter, typically produced by [`LoraAdapter::load`].
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Parsed `adapter_config.json`.
    pub config: LoraConfig,
    /// Factor pairs keyed by module path, with the `base_model.model.` prefix and the
    /// `.lora_A.weight` / `.lora_B.weight` suffix stripped
    /// (e.g. `encoder.layer.0.attention.self.query`).
    pub modules: HashMap<String, LoraModule>,
}
impl LoraAdapter {
pub fn load(adapter_dir: &Path, device: &Device) -> crate::Result<Self> {
let config_path = adapter_dir.join("adapter_config.json");
let cfg_str = std::fs::read_to_string(&config_path).map_err(|e| {
crate::Error::Backend(format!("lora: read {}: {e}", config_path.display()))
})?;
let config: LoraConfig = serde_json::from_str(&cfg_str).map_err(|e| {
crate::Error::Backend(format!("lora: parse {}: {e}", config_path.display()))
})?;
if config.r == 0 {
return Err(crate::Error::Backend(
"lora: adapter_config.json has r=0; refusing to merge".into(),
));
}
let weights_path = if adapter_dir.join("adapter_model.safetensors").exists() {
adapter_dir.join("adapter_model.safetensors")
} else if adapter_dir.join("adapter_weights.safetensors").exists() {
adapter_dir.join("adapter_weights.safetensors")
} else {
return Err(crate::Error::Backend(format!(
"lora: no adapter_model.safetensors or adapter_weights.safetensors in {}",
adapter_dir.display()
)));
};
let bytes = std::fs::read(&weights_path).map_err(|e| {
crate::Error::Backend(format!("lora: read {}: {e}", weights_path.display()))
})?;
let st = SafeTensors::deserialize(&bytes).map_err(|e| {
crate::Error::Backend(format!("lora: deserialize {}: {e}", weights_path.display()))
})?;
let mut by_module: HashMap<String, (Option<Tensor>, Option<Tensor>)> = HashMap::new();
for (key, view) in st.tensors() {
let (module_path, slot) = parse_lora_key(&key)?;
let shape: Vec<usize> = view.shape().to_vec();
if view.dtype() != safetensors::Dtype::F32 {
return Err(crate::Error::Backend(format!(
"lora: {key}: dtype {:?} not supported (Phase 4 ships fp32 only)",
view.dtype()
)));
}
let bytes = view.data();
let n = bytes.len() / 4;
let mut data = Vec::with_capacity(n);
for chunk in bytes.chunks_exact(4) {
data.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
}
let tensor = Tensor::from_vec(data, shape, device)
.map_err(|e| crate::Error::Backend(format!("lora: tensor {key}: {e}")))?;
let entry = by_module.entry(module_path).or_default();
match slot {
LoraSlot::A => entry.0 = Some(tensor),
LoraSlot::B => entry.1 = Some(tensor),
}
}
let mut modules = HashMap::new();
for (path, (a, b)) in by_module {
let lora_a = a.ok_or_else(|| {
crate::Error::Backend(format!("lora: missing lora_A for module {path}"))
})?;
let lora_b = b.ok_or_else(|| {
crate::Error::Backend(format!("lora: missing lora_B for module {path}"))
})?;
modules.insert(path, LoraModule { lora_a, lora_b });
}
Ok(Self { config, modules })
}
}
/// Which half of a LoRA factor pair a safetensors key refers to.
#[derive(Clone, Copy, Debug)]
enum LoraSlot {
    /// The `lora_A` factor.
    A,
    /// The `lora_B` factor.
    B,
}
fn parse_lora_key(key: &str) -> crate::Result<(String, LoraSlot)> {
let stripped = key.strip_prefix("base_model.model.").ok_or_else(|| {
crate::Error::Backend(format!(
"lora: key {key} does not start with 'base_model.model.'"
))
})?;
if let Some(path) = stripped.strip_suffix(".lora_A.weight") {
Ok((path.to_string(), LoraSlot::A))
} else if let Some(path) = stripped.strip_suffix(".lora_B.weight") {
Ok((path.to_string(), LoraSlot::B))
} else {
Err(crate::Error::Backend(format!(
"lora: key {key} does not end with '.lora_A.weight' or '.lora_B.weight'"
)))
}
}
pub(crate) fn merge_into_base(
base_safetensors: &Path,
adapter: &LoraAdapter,
device: &Device,
) -> crate::Result<HashMap<String, Tensor>> {
let bytes = std::fs::read(base_safetensors).map_err(|e| {
crate::Error::Backend(format!(
"lora_merge: read {}: {e}",
base_safetensors.display()
))
})?;
let st = SafeTensors::deserialize(&bytes).map_err(|e| {
crate::Error::Backend(format!(
"lora_merge: deserialize {}: {e}",
base_safetensors.display()
))
})?;
let scale = adapter.config.lora_alpha / (adapter.config.r as f64);
let mut out: HashMap<String, Tensor> = HashMap::with_capacity(st.tensors().len());
let mut applied: std::collections::HashSet<String> = std::collections::HashSet::new();
for (key, view) in st.tensors() {
let shape: Vec<usize> = view.shape().to_vec();
let mut tensor = decode_view(&view, shape, device)
.map_err(|e| crate::Error::Backend(format!("lora_merge: decode {key}: {e}")))?;
if let Some(mod_path) = key.strip_suffix(".weight") {
if let Some(lora_mod) = adapter.modules.get(mod_path) {
tensor = apply_lora_delta(
&tensor,
&lora_mod.lora_a,
&lora_mod.lora_b,
scale,
adapter.config.fan_in_fan_out,
)
.map_err(|e| {
crate::Error::Backend(format!("lora_merge: apply delta to {mod_path}: {e}"))
})?;
applied.insert(mod_path.to_string());
}
}
out.insert(key.to_string(), tensor);
}
for adapter_path in adapter.modules.keys() {
if !applied.contains(adapter_path) {
return Err(crate::Error::Backend(format!(
"lora_merge: adapter targets module '{adapter_path}' but no \
matching key '{adapter_path}.weight' found in base safetensors"
)));
}
}
Ok(out)
}
/// Decodes a safetensors view into a candle [`Tensor`] with the given `shape` on
/// `device`.
///
/// Payload bytes are read as little-endian. Only `F32` and `I64` are handled; any other
/// dtype yields `candle_core::Error::Msg`.
fn decode_view(
    view: &safetensors::tensor::TensorView<'_>,
    shape: Vec<usize>,
    device: &Device,
) -> candle_core::Result<Tensor> {
    let raw = view.data();
    match view.dtype() {
        safetensors::Dtype::F32 => {
            let values: Vec<f32> = raw
                .chunks_exact(4)
                .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
                .collect();
            Tensor::from_vec(values, shape, device)
        }
        safetensors::Dtype::I64 => {
            let values: Vec<i64> = raw
                .chunks_exact(8)
                .map(|w| i64::from_le_bytes([w[0], w[1], w[2], w[3], w[4], w[5], w[6], w[7]]))
                .collect();
            Tensor::from_vec(values, shape, device)
        }
        other => Err(candle_core::Error::Msg(format!(
            "lora_merge: dtype {other:?} not supported (only F32, I64; the GLiNER2 \
             base safetensors has only F32 and I64)"
        ))),
    }
}
fn apply_lora_delta(
base: &Tensor, lora_a: &Tensor, lora_b: &Tensor, scale: f64,
fan_in_fan_out: bool,
) -> candle_core::Result<Tensor> {
let delta = lora_b.matmul(lora_a)?; let delta = (delta * scale)?;
let delta = if fan_in_fan_out {
delta.t()?.contiguous()? } else {
delta
};
if base.shape().dims() != delta.shape().dims() {
return Err(candle_core::Error::Msg(format!(
"lora_merge: base shape {:?} != delta shape {:?} (fan_in_fan_out={fan_in_fan_out})",
base.shape().dims(),
delta.shape().dims(),
)));
}
base.add(&delta)
}
// As far as this file shows, `DType` is otherwise referenced only from the test module.
// This never-called function keeps the top-level `use candle_core::DType` from raising
// an unused-import warning in non-test builds; hence the `dead_code` allowance.
#[allow(dead_code)]
fn _dtype_marker() -> DType {
    DType::F32
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only exact PEFT `base_model.model.<path>.lora_{A,B}.weight` keys are accepted.
    #[test]
    fn parse_lora_key_strict() {
        let (path, slot) =
            parse_lora_key("base_model.model.encoder.layer.0.attention.self.query.lora_A.weight")
                .expect("valid PEFT key should parse");
        assert_eq!(path, "encoder.layer.0.attention.self.query");
        assert!(matches!(slot, LoraSlot::A));

        let (path_b, slot_b) =
            parse_lora_key("base_model.model.encoder.layer.0.attention.self.query.lora_B.weight")
                .expect("valid PEFT key (B) should parse");
        assert_eq!(path_b, "encoder.layer.0.attention.self.query");
        assert!(matches!(slot_b, LoraSlot::B));

        assert!(
            parse_lora_key("encoder.layer.0.attention.self.query.lora_A.weight").is_err(),
            "missing 'base_model.model.' prefix should fail"
        );
        assert!(
            parse_lora_key("base_model.model.encoder.layer.0.weight").is_err(),
            "missing '.lora_A.weight'/'.lora_B.weight' suffix should fail"
        );
    }

    /// B·A with all-ones factors of rank 2 is all 2.0; a 0.5 scale makes every delta 1.0.
    #[test]
    fn apply_lora_delta_shape() {
        let device = Device::Cpu;
        let base = Tensor::zeros((4, 3), DType::F32, &device).unwrap();
        let lora_a = Tensor::ones((2, 3), DType::F32, &device).unwrap();
        let lora_b = Tensor::ones((4, 2), DType::F32, &device).unwrap();
        let merged = apply_lora_delta(&base, &lora_a, &lora_b, 0.5, false).unwrap();

        assert_eq!(merged.shape().dims(), &[4, 3]);
        let values = merged.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        for x in values {
            assert!((x - 1.0).abs() < 1e-6);
        }
    }

    /// With `fan_in_fan_out`, the (4, 3) update is transposed to match a (3, 4) base.
    #[test]
    fn apply_lora_delta_fan_in_fan_out_transposes() {
        let device = Device::Cpu;
        let base = Tensor::zeros((3, 4), DType::F32, &device).unwrap();
        let lora_a = Tensor::ones((2, 3), DType::F32, &device).unwrap();
        let lora_b = Tensor::ones((4, 2), DType::F32, &device).unwrap();
        let merged = apply_lora_delta(&base, &lora_a, &lora_b, 0.5, true).unwrap();

        assert_eq!(merged.shape().dims(), &[3, 4]);
        let values = merged.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        for x in values {
            assert!((x - 1.0).abs() < 1e-6);
        }
    }

    /// A transposed (3, 4) update cannot be added to a (4, 3) base.
    #[test]
    fn apply_lora_delta_rejects_shape_mismatch() {
        let device = Device::Cpu;
        let base = Tensor::zeros((4, 3), DType::F32, &device).unwrap();
        let lora_a = Tensor::ones((2, 3), DType::F32, &device).unwrap();
        let lora_b = Tensor::ones((4, 2), DType::F32, &device).unwrap();

        assert!(apply_lora_delta(&base, &lora_a, &lora_b, 0.5, true).is_err());
    }
}