Skip to main content

entrenar/lora/adapter/
merge_export.rs

1//! Merged model export — merge LoRA/QLoRA adapters into base weights and export
2//!
3//! Supports merging adapters back into the base model and collecting the
4//! merged weights. When the `hub` feature is enabled, also supports direct
5//! export to SafeTensors and GGUF formats.
6
7use super::error::AdapterError;
8use crate::lora::LoRALayer;
9use crate::lora::QLoRALayer;
10use std::collections::HashMap;
11use std::path::Path;
12
/// Result of folding LoRA/QLoRA adapters into their base weights.
///
/// Holds one flat `f32` buffer per merged layer together with the shape that
/// buffer should be interpreted with. Both maps are keyed by the layer name
/// supplied by the caller.
#[derive(Debug, Clone, Default)]
pub struct MergedModel {
    /// Merged weights (base + adapter contribution), keyed by tensor name.
    pub tensors: HashMap<String, Vec<f32>>,
    /// Shape of each tensor, keyed by the same names as `tensors`.
    pub shapes: HashMap<String, Vec<usize>>,
    /// Number of layers that were merged to produce this model.
    pub layers_merged: usize,
}
22
23impl MergedModel {
24    /// Total parameter count
25    pub fn param_count(&self) -> u64 {
26        self.tensors.values().map(|t| t.len() as u64).sum()
27    }
28
29    /// Save merged model as SafeTensors
30    pub fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), AdapterError> {
31        use safetensors::tensor::{Dtype, TensorView};
32
33        let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = self
34            .tensors
35            .iter()
36            .map(|(name, data)| {
37                let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
38                let shape = self.shapes.get(name).cloned().unwrap_or_else(|| vec![data.len()]);
39                (name.clone(), bytes, shape)
40            })
41            .collect();
42
43        let views: Vec<(&str, TensorView<'_>)> = tensor_data
44            .iter()
45            .map(|(name, bytes, shape)| {
46                let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
47                    .expect("TensorView construction must not fail for valid F32 data");
48                (name.as_str(), view)
49            })
50            .collect();
51
52        let mut metadata = std::collections::HashMap::new();
53        metadata.insert("format".to_string(), "entrenar-merged".to_string());
54
55        let safetensor_bytes = safetensors::serialize(views, Some(metadata))
56            .map_err(|e| AdapterError::SafeTensors(format!("Serialization failed: {e}")))?;
57
58        std::fs::write(path, safetensor_bytes)?;
59        Ok(())
60    }
61}
62
63/// Merge LoRA layers into base weights and collect as merged model
64///
65/// Each entry is (layer_name, LoRALayer). The LoRA layer is cloned and merged,
66/// producing the merged base weight.
67pub fn merge_and_collect(layers: &[(&str, &LoRALayer)]) -> MergedModel {
68    let mut tensors = HashMap::new();
69    let mut shapes = HashMap::new();
70
71    for &(name, layer) in layers {
72        let mut cloned = layer.clone();
73        cloned.merge();
74        let data = cloned.base_weight().data().to_vec();
75        shapes.insert(name.to_string(), vec![layer.d_out(), layer.d_in()]);
76        tensors.insert(name.to_string(), data);
77    }
78
79    MergedModel { layers_merged: layers.len(), tensors, shapes }
80}
81
82/// Merge QLoRA layers into f32 weights and collect as merged model
83///
84/// Dequantizes 4-bit base + adapter contribution for each layer.
85pub fn merge_qlora_and_collect(layers: &[(&str, &QLoRALayer)]) -> MergedModel {
86    let mut tensors = HashMap::new();
87    let mut shapes = HashMap::new();
88
89    for &(name, layer) in layers {
90        let data = layer.merge_to_f32();
91        shapes.insert(name.to_string(), vec![layer.d_out(), layer.d_in()]);
92        tensors.insert(name.to_string(), data);
93    }
94
95    MergedModel { layers_merged: layers.len(), tensors, shapes }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;
    use tempfile::TempDir;

    /// Build a LoRA layer over a `d_out * d_in` base buffer filled with 0.5.
    fn make_lora_layer(d_out: usize, d_in: usize, rank: usize) -> LoRALayer {
        let weights = vec![0.5; d_out * d_in];
        LoRALayer::new(Tensor::from_vec(weights, false), d_out, d_in, rank, 8.0)
    }

    /// Two named LoRA layers yield two merged tensors.
    #[test]
    fn test_merge_and_collect_lora() {
        let q_proj = make_lora_layer(8, 16, 4);
        let v_proj = make_lora_layer(8, 16, 4);
        let named = [
            ("model.layers.0.q_proj.weight", &q_proj),
            ("model.layers.0.v_proj.weight", &v_proj),
        ];

        let model = merge_and_collect(&named);

        assert_eq!(model.layers_merged, 2);
        assert_eq!(model.tensors.len(), 2);
        assert!(model.param_count() > 0);
    }

    /// A QLoRA layer merges into a full-size f32 buffer.
    #[test]
    fn test_merge_qlora_and_collect() {
        let base_weight = Tensor::from_vec(vec![0.5; 8 * 16], false);
        let layer = QLoRALayer::new(base_weight, 8, 16, 4, 8.0);
        let named = [("model.layers.0.q_proj.weight", &layer)];

        let model = merge_qlora_and_collect(&named);

        assert_eq!(model.layers_merged, 1);
        assert_eq!(model.tensors.len(), 1);

        let weights = model.tensors.get("model.layers.0.q_proj.weight").expect("key should exist");
        assert_eq!(weights.len(), 8 * 16);
    }

    /// The written file round-trips through the safetensors deserializer.
    #[test]
    fn test_save_safetensors() {
        let layer = make_lora_layer(8, 8, 4);
        let model = merge_and_collect(&[("weight", &layer)]);

        let dir = TempDir::new().expect("temp file creation should succeed");
        let out_path = dir.path().join("merged.safetensors");
        model.save_safetensors(&out_path).expect("save should succeed");

        // Read the file back and check it parses as safetensors with our tensor in it.
        let raw = std::fs::read(&out_path).expect("file read should succeed");
        let parsed = safetensors::SafeTensors::deserialize(&raw).expect("load should succeed");
        assert_eq!(parsed.len(), 1);
        let names = parsed.names();
        assert!(names.contains(&"weight"));
    }

    /// Merging no layers produces an empty model.
    #[test]
    fn test_merge_empty() {
        let model = merge_and_collect(&[]);
        assert_eq!(model.layers_merged, 0);
        assert!(model.tensors.is_empty());
    }

    /// The recorded shape is `[d_out, d_in]`.
    #[test]
    fn test_merge_preserves_shapes() {
        let layer = make_lora_layer(8, 16, 4);
        let model = merge_and_collect(&[("w", &layer)]);

        let shape = model.shapes.get("w").expect("key should exist");
        assert_eq!(shape, &vec![8, 16]);
    }
}
176}