entrenar/lora/adapter/
merge_export.rs1use super::error::AdapterError;
8use crate::lora::LoRALayer;
9use crate::lora::QLoRALayer;
10use std::collections::HashMap;
11use std::path::Path;
12
/// Collection of merged weight tensors produced by [`merge_and_collect`] or
/// [`merge_qlora_and_collect`], ready to be exported.
///
/// `tensors` and `shapes` are keyed by the same tensor names; the fields are
/// public, so callers may also populate or adjust them by hand.
#[derive(Debug, Clone, Default)]
pub struct MergedModel {
    /// Flat `f32` values of each merged tensor, keyed by tensor name.
    pub tensors: HashMap<String, Vec<f32>>,
    /// Dimensions of each tensor, keyed by tensor name.
    ///
    /// When a name has no entry here, `save_safetensors` falls back to a
    /// one-dimensional shape equal to the tensor's element count.
    pub shapes: HashMap<String, Vec<usize>>,
    /// Number of layers that were handed to the merge function.
    pub layers_merged: usize,
}
22
23impl MergedModel {
24 pub fn param_count(&self) -> u64 {
26 self.tensors.values().map(|t| t.len() as u64).sum()
27 }
28
29 pub fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), AdapterError> {
31 use safetensors::tensor::{Dtype, TensorView};
32
33 let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = self
34 .tensors
35 .iter()
36 .map(|(name, data)| {
37 let bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
38 let shape = self.shapes.get(name).cloned().unwrap_or_else(|| vec![data.len()]);
39 (name.clone(), bytes, shape)
40 })
41 .collect();
42
43 let views: Vec<(&str, TensorView<'_>)> = tensor_data
44 .iter()
45 .map(|(name, bytes, shape)| {
46 let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
47 .expect("TensorView construction must not fail for valid F32 data");
48 (name.as_str(), view)
49 })
50 .collect();
51
52 let mut metadata = std::collections::HashMap::new();
53 metadata.insert("format".to_string(), "entrenar-merged".to_string());
54
55 let safetensor_bytes = safetensors::serialize(views, Some(metadata))
56 .map_err(|e| AdapterError::SafeTensors(format!("Serialization failed: {e}")))?;
57
58 std::fs::write(path, safetensor_bytes)?;
59 Ok(())
60 }
61}
62
63pub fn merge_and_collect(layers: &[(&str, &LoRALayer)]) -> MergedModel {
68 let mut tensors = HashMap::new();
69 let mut shapes = HashMap::new();
70
71 for &(name, layer) in layers {
72 let mut cloned = layer.clone();
73 cloned.merge();
74 let data = cloned.base_weight().data().to_vec();
75 shapes.insert(name.to_string(), vec![layer.d_out(), layer.d_in()]);
76 tensors.insert(name.to_string(), data);
77 }
78
79 MergedModel { layers_merged: layers.len(), tensors, shapes }
80}
81
82pub fn merge_qlora_and_collect(layers: &[(&str, &QLoRALayer)]) -> MergedModel {
86 let mut tensors = HashMap::new();
87 let mut shapes = HashMap::new();
88
89 for &(name, layer) in layers {
90 let data = layer.merge_to_f32();
91 shapes.insert(name.to_string(), vec![layer.d_out(), layer.d_in()]);
92 tensors.insert(name.to_string(), data);
93 }
94
95 MergedModel { layers_merged: layers.len(), tensors, shapes }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;
    use tempfile::TempDir;

    /// Builds a LoRA layer whose base weight is `d_out * d_in` values of 0.5.
    /// The trailing `8.0` is passed straight through to `LoRALayer::new`.
    fn make_lora_layer(d_out: usize, d_in: usize, rank: usize) -> LoRALayer {
        let base_weight = Tensor::from_vec(vec![0.5; d_out * d_in], false);
        LoRALayer::new(base_weight, d_out, d_in, rank, 8.0)
    }

    /// Two distinct names yield two tensors and a non-zero parameter count.
    #[test]
    fn test_merge_and_collect_lora() {
        let q_proj = make_lora_layer(8, 16, 4);
        let v_proj = make_lora_layer(8, 16, 4);
        let named = [
            ("model.layers.0.q_proj.weight", &q_proj),
            ("model.layers.0.v_proj.weight", &v_proj),
        ];

        let model = merge_and_collect(&named);

        assert_eq!(model.layers_merged, 2);
        assert_eq!(model.tensors.len(), 2);
        assert!(model.param_count() > 0);
    }

    /// A single QLoRA layer produces one tensor with `d_out * d_in` values.
    #[test]
    fn test_merge_qlora_and_collect() {
        let base_weight = Tensor::from_vec(vec![0.5; 8 * 16], false);
        let layer = QLoRALayer::new(base_weight, 8, 16, 4, 8.0);
        let named = [("model.layers.0.q_proj.weight", &layer)];

        let model = merge_qlora_and_collect(&named);

        assert_eq!(model.layers_merged, 1);
        assert_eq!(model.tensors.len(), 1);

        let values = model.tensors.get("model.layers.0.q_proj.weight").expect("key should exist");
        assert_eq!(values.len(), 8 * 16);
    }

    /// The written file can be read back and contains the saved tensor name.
    #[test]
    fn test_save_safetensors() {
        let layer = make_lora_layer(8, 8, 4);
        let named = [("weight", &layer)];
        let model = merge_and_collect(&named);

        let dir = TempDir::new().expect("temp file creation should succeed");
        let out_path = dir.path().join("merged.safetensors");
        model.save_safetensors(&out_path).expect("save should succeed");

        let raw = std::fs::read(&out_path).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&raw).expect("load should succeed");
        assert_eq!(loaded.len(), 1);
        let names = loaded.names();
        assert!(names.contains(&"weight"));
    }

    /// An empty layer list gives an empty model.
    #[test]
    fn test_merge_empty() {
        let named: [(&str, &LoRALayer); 0] = [];
        let model = merge_and_collect(&named);
        assert_eq!(model.layers_merged, 0);
        assert!(model.tensors.is_empty());
    }

    /// The recorded shape is `[d_out, d_in]`.
    #[test]
    fn test_merge_preserves_shapes() {
        let layer = make_lora_layer(8, 16, 4);
        let named = [("w", &layer)];
        let model = merge_and_collect(&named);

        assert_eq!(model.shapes.get("w").expect("key should exist"), &vec![8, 16]);
    }
}