1use super::error::AdapterError;
6use super::peft_config::PeftAdapterConfig;
7use crate::lora::LoRAConfig;
8use crate::lora::LoRALayer;
9use safetensors::tensor::{Dtype, TensorView};
10use std::collections::HashMap;
11use std::path::Path;
12
13pub struct PeftAdapterBundle {
17 adapters: Vec<(String, AdapterWeights)>,
19 config: LoRAConfig,
21 base_model: Option<String>,
23}
24
/// Flat `f32` storage for one layer's low-rank matrices together with the
/// dimensions used to describe their shapes on export.
struct AdapterWeights {
    /// LoRA rank (shared dimension of the two matrices).
    rank: usize,
    /// Input dimension of the adapted layer.
    d_in: usize,
    /// Output dimension of the adapted layer.
    d_out: usize,
    /// `lora_A` values; exported with shape `[rank, d_in]`.
    lora_a: Vec<f32>,
    /// `lora_B` values; exported with shape `[d_out, rank]`.
    lora_b: Vec<f32>,
}
38
39impl PeftAdapterBundle {
40 pub fn new(config: LoRAConfig) -> Self {
42 Self { adapters: Vec::new(), config, base_model: None }
43 }
44
45 pub fn with_base_model(mut self, name: impl Into<String>) -> Self {
47 self.base_model = Some(name.into());
48 self
49 }
50
51 pub fn add_adapter(&mut self, layer_path: impl Into<String>, layer: &LoRALayer) {
56 let weights = AdapterWeights {
57 lora_a: layer.lora_a().data().to_vec(),
58 lora_b: layer.lora_b().data().to_vec(),
59 rank: layer.rank(),
60 d_in: layer.d_in(),
61 d_out: layer.d_out(),
62 };
63 self.adapters.push((layer_path.into(), weights));
64 }
65
66 pub fn add_raw_adapter(
68 &mut self,
69 layer_path: impl Into<String>,
70 lora_a: Vec<f32>,
71 lora_b: Vec<f32>,
72 rank: usize,
73 d_in: usize,
74 d_out: usize,
75 ) {
76 self.adapters
77 .push((layer_path.into(), AdapterWeights { lora_a, lora_b, rank, d_in, d_out }));
78 }
79
80 pub fn save_peft(&self, output_dir: impl AsRef<Path>) -> Result<(), AdapterError> {
86 let output_dir = output_dir.as_ref();
87 std::fs::create_dir_all(output_dir)?;
88
89 let peft_config =
91 PeftAdapterConfig::from_lora_config(&self.config, self.base_model.as_deref());
92 let config_json = peft_config.to_json().map_err(AdapterError::Serialization)?;
93 std::fs::write(output_dir.join("adapter_config.json"), config_json)?;
94
95 let mut tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
98
99 for (layer_path, weights) in &self.adapters {
100 let a_name = format!("base_model.model.{layer_path}.lora_A.weight");
102 let a_bytes: Vec<u8> = bytemuck::cast_slice(&weights.lora_a).to_vec();
103 let a_shape = vec![weights.rank, weights.d_in];
104 tensor_data.push((a_name, a_bytes, a_shape));
105
106 let b_name = format!("base_model.model.{layer_path}.lora_B.weight");
108 let b_bytes: Vec<u8> = bytemuck::cast_slice(&weights.lora_b).to_vec();
109 let b_shape = vec![weights.d_out, weights.rank];
110 tensor_data.push((b_name, b_bytes, b_shape));
111 }
112
113 let views: Vec<(&str, TensorView<'_>)> = tensor_data
115 .iter()
116 .map(|(name, bytes, shape)| {
117 let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
118 .expect("TensorView construction must not fail for valid F32 data");
119 (name.as_str(), view)
120 })
121 .collect();
122
123 let mut metadata = HashMap::new();
125 metadata.insert("format".to_string(), "pt".to_string());
126
127 let safetensor_bytes = safetensors::serialize(views, Some(metadata)).map_err(|e| {
128 AdapterError::SafeTensors(format!("SafeTensors serialization failed: {e}"))
129 })?;
130
131 std::fs::write(output_dir.join("adapter_model.safetensors"), safetensor_bytes)?;
132
133 Ok(())
134 }
135
136 pub fn len(&self) -> usize {
138 self.adapters.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.adapters.is_empty()
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lora::LoRALayer;
    use crate::Tensor;
    use tempfile::TempDir;

    /// Tensor key written for the `lora_A` matrix of the layer-0 q_proj adapter.
    const Q_PROJ_A_KEY: &str = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight";
    /// Tensor key written for the `lora_B` matrix of the layer-0 q_proj adapter.
    const Q_PROJ_B_KEY: &str = "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight";

    /// Builds a LoRA layer on top of a zero-filled base weight of
    /// `d_out * d_in` elements.
    fn make_test_layer(d_out: usize, d_in: usize, rank: usize) -> LoRALayer {
        LoRALayer::new(Tensor::zeros(d_out * d_in, false), d_out, d_in, rank, 16.0)
    }

    /// Saves `bundle` into a fresh temporary directory and returns the
    /// directory guard so the files outlive this call.
    fn save_to_temp(bundle: &PeftAdapterBundle) -> TempDir {
        let dir = TempDir::new().expect("temp file creation should succeed");
        bundle.save_peft(dir.path()).expect("save should succeed");
        dir
    }

    /// Reads the raw safetensors file produced by `save_peft` from `dir`.
    fn read_safetensors(dir: &TempDir) -> Vec<u8> {
        std::fs::read(dir.path().join("adapter_model.safetensors"))
            .expect("file read should succeed")
    }

    /// Asserts that both exported artifacts are present in `dir`.
    fn assert_peft_files_exist(dir: &TempDir) {
        assert!(dir.path().join("adapter_config.json").exists());
        assert!(dir.path().join("adapter_model.safetensors").exists());
    }

    #[test]
    fn test_bundle_creation() {
        let bundle = PeftAdapterBundle::new(LoRAConfig::new(8, 16.0).target_qv_projections());

        assert!(bundle.is_empty());
        assert_eq!(bundle.len(), 0);
    }

    #[test]
    fn test_add_adapter() {
        let mut bundle = PeftAdapterBundle::new(LoRAConfig::new(8, 16.0).target_qv_projections());

        bundle.add_adapter("model.layers.0.self_attn.q_proj", &make_test_layer(64, 64, 8));

        assert_eq!(bundle.len(), 1);
        assert!(!bundle.is_empty());
    }

    #[test]
    fn test_save_peft_creates_files() {
        let lora_config = LoRAConfig::new(4, 8.0).target_qv_projections();
        let mut bundle =
            PeftAdapterBundle::new(lora_config).with_base_model("meta-llama/Llama-2-7b");
        bundle.add_adapter("model.layers.0.self_attn.q_proj", &make_test_layer(16, 16, 4));

        let out_dir = save_to_temp(&bundle);

        assert_peft_files_exist(&out_dir);
    }

    #[test]
    fn test_save_peft_config_content() {
        let lora_config = LoRAConfig::new(16, 32.0).target_attention_projections();
        let bundle = PeftAdapterBundle::new(lora_config).with_base_model("test/model");

        let out_dir = save_to_temp(&bundle);

        let raw_json = std::fs::read_to_string(out_dir.path().join("adapter_config.json"))
            .expect("file read should succeed");
        let parsed: PeftAdapterConfig =
            serde_json::from_str(&raw_json).expect("JSON deserialization should succeed");

        assert_eq!(parsed.peft_type, "LORA");
        assert_eq!(parsed.r, 16);
        assert_eq!(parsed.lora_alpha, 32.0);
        assert_eq!(parsed.base_model_name_or_path, Some(String::from("test/model")));
    }

    #[test]
    fn test_save_peft_safetensors_content() {
        let mut bundle = PeftAdapterBundle::new(LoRAConfig::new(4, 8.0).target_qv_projections());
        // d_out = 8, d_in = 16, rank = 4.
        bundle.add_adapter("model.layers.0.self_attn.q_proj", &make_test_layer(8, 16, 4));

        let out_dir = save_to_temp(&bundle);
        let raw = read_safetensors(&out_dir);
        let loaded = safetensors::SafeTensors::deserialize(&raw).expect("load should succeed");

        let names = loaded.names();
        assert!(names.contains(&Q_PROJ_A_KEY));
        assert!(names.contains(&Q_PROJ_B_KEY));

        // lora_A is stored as [rank, d_in].
        let lora_a = loaded.tensor(Q_PROJ_A_KEY).expect("operation should succeed");
        assert_eq!(lora_a.shape(), &[4, 16]);

        // lora_B is stored as [d_out, rank].
        let lora_b = loaded.tensor(Q_PROJ_B_KEY).expect("operation should succeed");
        assert_eq!(lora_b.shape(), &[8, 4]);
    }

    #[test]
    fn test_save_peft_multiple_layers() {
        let mut bundle = PeftAdapterBundle::new(LoRAConfig::new(4, 8.0).target_qv_projections());

        for layer_idx in 0..3 {
            let path = format!("model.layers.{layer_idx}.self_attn.q_proj");
            bundle.add_adapter(path, &make_test_layer(8, 8, 4));
        }
        assert_eq!(bundle.len(), 3);

        let out_dir = save_to_temp(&bundle);
        let raw = read_safetensors(&out_dir);
        let loaded = safetensors::SafeTensors::deserialize(&raw).expect("load should succeed");

        // Two tensors (lora_A and lora_B) per adapter.
        assert_eq!(loaded.len(), 6);
    }

    #[test]
    fn test_save_peft_empty_bundle() {
        let bundle = PeftAdapterBundle::new(LoRAConfig::new(4, 8.0));

        let out_dir = save_to_temp(&bundle);

        assert_peft_files_exist(&out_dir);
    }
}