use std::collections::HashMap;
use std::path::PathBuf;
use anyhow::{Context, Result};
use candle_core::safetensors::MmapedSafetensors;
use candle_core::{DType, Device, Tensor};
/// Filesystem locations for a single LoRA merge job.
#[derive(Debug, Clone)]
pub struct MergeConfig {
    /// Directory holding the base model (`config.json` plus safetensors shards).
    pub base_dir: PathBuf,
    /// Path to the adapter safetensors file containing `lora_a`/`lora_b` tensors.
    pub adapter_path: PathBuf,
    /// Directory where the merged model and copied sidecar files are written.
    pub output_dir: PathBuf,
}
/// Progress events streamed from the merge worker thread to the UI/consumer.
#[derive(Debug, Clone)]
pub enum MergeProgress {
    /// Reading base model config and mapping weight files.
    Loading,
    /// Merging LoRA deltas; `layer` is 1-based, `total` is the number of
    /// layers that have adapter tensors.
    Merging { layer: usize, total: usize },
    /// Writing the merged safetensors file and sidecar files.
    Saving,
    /// Merge finished; `output_path` is the merged model file.
    Done { output_path: PathBuf },
    /// Merge failed; the payload is a human-readable error chain.
    Failed(String),
}
/// Launch a merge on a dedicated OS thread, reporting progress through `tx`.
///
/// Any error from `run_merge` is printed to stderr and forwarded to the
/// channel as a `Failed` event; the send result is intentionally ignored
/// because the receiver may already have been dropped.
pub fn start_merge(config: MergeConfig, tx: tokio::sync::mpsc::UnboundedSender<MergeProgress>) {
    std::thread::spawn(move || match run_merge(&config, &tx) {
        Ok(()) => {}
        Err(e) => {
            eprintln!("[merge] error: {e:#}");
            let _ = tx.send(MergeProgress::Failed(format!("{e:#}")));
        }
    });
}
/// Minimal view of the model's `config.json` — only the field the merge needs.
#[derive(serde::Deserialize)]
struct MinimalConfig {
    // Number of transformer layers; used to size progress reporting.
    num_hidden_layers: usize,
}
/// Minimal view of `model.safetensors.index.json` for sharded checkpoints.
#[derive(serde::Deserialize)]
struct IndexJson {
    // Maps tensor name -> shard filename; only the filenames are used here.
    weight_map: HashMap<String, String>,
}
/// Merge LoRA adapter weights into a base model and write the result to disk.
///
/// Steps: read `config.json` for the layer count, mmap the base safetensors
/// (sharded or single-file), load the adapter onto CPU, fold
/// `base + (lora_b @ lora_a) * scale` into every base tensor that has matching
/// `lora_a`/`lora_b` entries, save the merged model, and copy tokenizer/config
/// sidecar files into the output directory.
///
/// Progress events are streamed through `tx`; send failures are ignored since
/// the receiver may have been dropped.
fn run_merge(
    config: &MergeConfig,
    tx: &tokio::sync::mpsc::UnboundedSender<MergeProgress>,
) -> Result<()> {
    let _ = tx.send(MergeProgress::Loading);

    // The layer count drives the progress total; read it from config.json.
    let config_path = config.base_dir.join("config.json");
    let config_text = std::fs::read_to_string(&config_path)
        .with_context(|| format!("failed to read {}", config_path.display()))?;
    let model_cfg: MinimalConfig =
        serde_json::from_str(&config_text).context("failed to parse config.json")?;
    let num_layers = model_cfg.num_hidden_layers;

    let base_paths = resolve_base_paths(&config.base_dir)?;
    // SAFETY: the mmap'd checkpoint files are only read here and are not
    // expected to be mutated while mapped.
    let base = if base_paths.len() == 1 {
        unsafe { MmapedSafetensors::new(&base_paths[0]) }
    } else {
        unsafe { MmapedSafetensors::multi(&base_paths) }
    }
    .context("failed to mmap base safetensors")?;

    let adapter =
        candle_core::safetensors::load(&config.adapter_path, &Device::Cpu).with_context(|| {
            format!(
                "failed to load adapter from {}",
                config.adapter_path.display()
            )
        })?;

    // Validate that the adapter actually contains LoRA tensors. dim(0) of a
    // lora_a matrix is the adapter rank; the value itself is currently unused
    // because `scale` is hard-coded below.
    let rank = adapter
        .iter()
        .find(|(k, _)| k.ends_with("lora_a"))
        .map(|(_, t)| t.dim(0))
        .context("no lora_a tensor found in adapter")?
        .context("failed to read lora_a dimension")?;
    let _ = rank;

    // Hard-coded LoRA scaling factor — presumably lora_alpha / rank.
    // TODO(review): confirm this matches the training configuration.
    let scale = 2.0_f64;

    let total_merge_layers = count_merge_layers(&adapter, num_layers);
    let mut merge_layer_idx: usize = 0;
    let mut reported_layers: std::collections::HashSet<usize> = std::collections::HashSet::new();

    let base_tensor_info = base.tensors();
    let mut merged_tensors: HashMap<String, Tensor> = HashMap::new();
    for (name, _tv) in &base_tensor_info {
        let base_tensor = base
            .load(name, &Device::Cpu)
            .with_context(|| format!("failed to load base tensor {name}"))?;
        let base_dtype = base_tensor.dtype();
        // Adapter keys are the base key minus the ".weight" suffix.
        let stem = name.strip_suffix(".weight").unwrap_or(name);
        let lora_a_key = format!("{stem}.lora_a");
        let lora_b_key = format!("{stem}.lora_b");
        let tensor = if let (Some(lora_a), Some(lora_b)) =
            (adapter.get(&lora_a_key), adapter.get(&lora_b_key))
        {
            // Report each layer at most once even though a layer contributes
            // several merged tensors (q_proj, v_proj, ...).
            if let Some(layer_num) = extract_layer_index(stem)
                && reported_layers.insert(layer_num)
            {
                merge_layer_idx += 1;
                let _ = tx.send(MergeProgress::Merging {
                    layer: merge_layer_idx,
                    total: total_merge_layers,
                });
            }
            // Accumulate in f32 for precision, then cast back to the base
            // dtype: merged = base + (B @ A) * scale.
            let base_f32 = base_tensor.to_dtype(DType::F32)?;
            let a_f32 = lora_a.to_dtype(DType::F32)?;
            let b_f32 = lora_b.to_dtype(DType::F32)?;
            let delta = (b_f32.matmul(&a_f32)? * scale)?;
            let merged = (base_f32 + delta)?;
            merged.to_dtype(base_dtype)?
        } else {
            base_tensor
        };
        merged_tensors.insert(name.clone(), tensor);
    }

    let _ = tx.send(MergeProgress::Saving);
    std::fs::create_dir_all(&config.output_dir).with_context(|| {
        format!(
            "failed to create output dir {}",
            config.output_dir.display()
        )
    })?;
    let output_model = config.output_dir.join("model.safetensors");
    candle_core::safetensors::save(&merged_tensors, &output_model)
        .with_context(|| format!("failed to save merged model to {}", output_model.display()))?;

    copy_sidecar_files(config);

    let _ = tx.send(MergeProgress::Done {
        output_path: output_model,
    });
    Ok(())
}

/// Resolve the safetensors shard paths for the base model: the files listed
/// in `model.safetensors.index.json` when it exists, otherwise the single
/// `model.safetensors`.
fn resolve_base_paths(base_dir: &std::path::Path) -> Result<Vec<PathBuf>> {
    let index_path = base_dir.join("model.safetensors.index.json");
    if !index_path.exists() {
        return Ok(vec![base_dir.join("model.safetensors")]);
    }
    let index_text = std::fs::read_to_string(&index_path)
        .with_context(|| format!("failed to read {}", index_path.display()))?;
    let index: IndexJson = serde_json::from_str(&index_text)
        .context("failed to parse model.safetensors.index.json")?;
    // Many tensors map to the same shard file; sort + dedup the filenames.
    let mut filenames: Vec<String> = index.weight_map.into_values().collect();
    filenames.sort();
    filenames.dedup();
    Ok(filenames.into_iter().map(|f| base_dir.join(f)).collect())
}

/// Count how many transformer layers have at least one LoRA tensor in the
/// adapter; used as the progress-bar total.
fn count_merge_layers(adapter: &HashMap<String, Tensor>, num_layers: usize) -> usize {
    (0..num_layers)
        .filter(|i| {
            adapter.contains_key(&format!("model.layers.{i}.self_attn.q_proj.lora_a"))
                || adapter.contains_key(&format!("model.layers.{i}.self_attn.v_proj.lora_a"))
        })
        .count()
}

/// Best-effort copy of tokenizer/config sidecar files next to the merged
/// model. Missing sources are skipped; copy failures only warn on stderr.
fn copy_sidecar_files(config: &MergeConfig) {
    let copy_files = [
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "generation_config.json",
    ];
    for filename in copy_files {
        let src = config.base_dir.join(filename);
        if src.exists() {
            let dst = config.output_dir.join(filename);
            if let Err(e) = std::fs::copy(&src, &dst) {
                eprintln!("[merge] warning: failed to copy {}: {e:#}", src.display());
            }
        }
    }
}
/// Pull the numeric layer index out of a tensor-name stem such as
/// `model.layers.7.self_attn.q_proj`. Returns `None` when no `layers`
/// segment exists or the segment after it is not a valid number.
fn extract_layer_index(stem: &str) -> Option<usize> {
    let mut segments = stem.split('.');
    // Advance to just past the first "layers" segment, then parse the next one.
    segments.find(|segment| *segment == "layers")?;
    segments.next()?.parse::<usize>().ok()
}