1use std::collections::HashMap;
10use std::fs;
11use std::path::Path;
12
13use candle_core::{Device, Tensor};
14use serde::{de::DeserializeOwned, Serialize};
15
16use crate::error::{PeftError, Result};
17
18pub trait SaveLoad {
20 fn state_dict(&self) -> Result<HashMap<String, Tensor>>;
26
27 fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()>;
33}
34
35pub fn save_adapter_weights<P: AsRef<Path>>(adapter: &dyn SaveLoad, path: P) -> Result<()> {
47 let state_dict = adapter.state_dict()?;
48
49 let tensors: Vec<(&str, Tensor)> = state_dict
51 .iter()
52 .map(|(name, tensor)| (name.as_str(), tensor.clone()))
53 .collect();
54
55 safetensors::tensor::serialize_to_file(tensors, None, path.as_ref())
57 .map_err(|e| PeftError::Io(format!("Failed to save safetensors: {e}")))?;
58
59 Ok(())
60}
61
62pub fn load_adapter_weights<P: AsRef<Path>>(
75 adapter: &mut dyn SaveLoad,
76 path: P,
77 device: &Device,
78) -> Result<()> {
79 let tensors = candle_core::safetensors::load(path.as_ref(), device)?;
81
82 adapter.load_state_dict(tensors)?;
84
85 Ok(())
86}
87
88pub fn save_adapter_config<T: Serialize, P: AsRef<Path>>(config: &T, path: P) -> Result<()> {
97 let json = serde_json::to_string_pretty(config)
98 .map_err(|e| PeftError::Io(format!("Failed to serialize config: {e}")))?;
99
100 fs::write(path, json)
101 .map_err(|e| PeftError::Io(format!("Failed to write config file: {e}")))?;
102
103 Ok(())
104}
105
106pub fn load_adapter_config<T: DeserializeOwned, P: AsRef<Path>>(path: P) -> Result<T> {
114 let json = fs::read_to_string(path)
115 .map_err(|e| PeftError::Io(format!("Failed to read config file: {e}")))?;
116
117 let config = serde_json::from_str(&json)
118 .map_err(|e| PeftError::Io(format!("Failed to parse config: {e}")))?;
119
120 Ok(config)
121}
122
/// File name under which [`save_pretrained`] stores the adapter weights and
/// from which [`load_pretrained`] reads them.
pub const ADAPTER_WEIGHTS_FILENAME: &str = "adapter_model.safetensors";

/// File name under which [`save_pretrained`] stores the adapter configuration
/// and from which [`load_pretrained`] reads it.
pub const ADAPTER_CONFIG_FILENAME: &str = "adapter_config.json";
128
129pub fn save_pretrained<T: Serialize, P: AsRef<Path>>(
153 adapter: &dyn SaveLoad,
154 config: &T,
155 dir: P,
156) -> Result<()> {
157 let dir = dir.as_ref();
158
159 if !dir.exists() {
161 fs::create_dir_all(dir)
162 .map_err(|e| PeftError::Io(format!("Failed to create directory: {e}")))?;
163 }
164
165 let weights_path = dir.join(ADAPTER_WEIGHTS_FILENAME);
167 save_adapter_weights(adapter, &weights_path)?;
168
169 let config_path = dir.join(ADAPTER_CONFIG_FILENAME);
171 save_adapter_config(config, &config_path)?;
172
173 Ok(())
174}
175
176pub fn load_pretrained<T: DeserializeOwned, P: AsRef<Path>>(
203 adapter: &mut dyn SaveLoad,
204 dir: P,
205 device: &Device,
206) -> Result<T> {
207 let dir = dir.as_ref();
208
209 if !dir.exists() {
210 return Err(PeftError::Io(format!(
211 "Directory does not exist: {}",
212 dir.display()
213 )));
214 }
215
216 let weights_path = dir.join(ADAPTER_WEIGHTS_FILENAME);
218 load_adapter_weights(adapter, &weights_path, device)?;
219
220 let config_path = dir.join(ADAPTER_CONFIG_FILENAME);
222 let config: T = load_adapter_config(&config_path)?;
223
224 Ok(config)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// In-memory [`SaveLoad`] implementation that simply stores the map it is given.
    struct MockAdapter {
        weights: HashMap<String, Tensor>,
    }

    impl SaveLoad for MockAdapter {
        fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
            Ok(self.weights.clone())
        }

        fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
            self.weights = state_dict;
            Ok(())
        }
    }

    /// Builds the fixture shared by the tests: random `lora_a` (64x8) and
    /// `lora_b` (8x64) tensors on `device`.
    fn random_lora_weights(device: &Device) -> anyhow::Result<HashMap<String, Tensor>> {
        let lora_a = Tensor::randn(0f32, 1f32, (64, 8), device)?;
        let lora_b = Tensor::randn(0f32, 1f32, (8, 64), device)?;

        Ok(HashMap::from([
            ("lora_a".to_string(), lora_a),
            ("lora_b".to_string(), lora_b),
        ]))
    }

    /// Sums all elements of `tensor` into a single `f32`.
    fn scalar_sum(tensor: &Tensor) -> anyhow::Result<f32> {
        Ok(tensor.sum_all()?.to_scalar::<f32>()?)
    }

    /// Round-trips weights through a safetensors file and compares names,
    /// shapes and element sums.
    #[test]
    #[allow(clippy::similar_names)]
    fn test_save_load_adapter_weights() -> anyhow::Result<()> {
        let device = Device::Cpu;
        let temp_dir = TempDir::new()?;
        let weights_path = temp_dir.path().join("adapter.safetensors");

        let weights = random_lora_weights(&device)?;
        let adapter = MockAdapter {
            weights: weights.clone(),
        };

        save_adapter_weights(&adapter, &weights_path)?;
        assert!(weights_path.exists());

        let mut loaded_adapter = MockAdapter {
            weights: HashMap::new(),
        };
        load_adapter_weights(&mut loaded_adapter, &weights_path, &device)?;

        assert_eq!(loaded_adapter.weights.len(), 2);
        for name in ["lora_a", "lora_b"] {
            assert!(loaded_adapter.weights.contains_key(name));
        }

        for name in ["lora_a", "lora_b"] {
            assert_eq!(loaded_adapter.weights[name].dims(), weights[name].dims());
        }

        let original_a_sum = scalar_sum(&weights["lora_a"])?;
        let loaded_a_sum = scalar_sum(&loaded_adapter.weights["lora_a"])?;
        assert!(
            (original_a_sum - loaded_a_sum).abs() < 1e-5,
            "lora_a sum mismatch: {original_a_sum} vs {loaded_a_sum}"
        );

        let original_b_sum = scalar_sum(&weights["lora_b"])?;
        let loaded_b_sum = scalar_sum(&loaded_adapter.weights["lora_b"])?;
        assert!(
            (original_b_sum - loaded_b_sum).abs() < 1e-5,
            "lora_b sum mismatch: {original_b_sum} vs {loaded_b_sum}"
        );

        Ok(())
    }

    /// Round-trips a small config struct through a JSON file.
    #[test]
    fn test_save_load_config() -> anyhow::Result<()> {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
        struct TestConfig {
            r: usize,
            alpha: usize,
            dropout: f64,
        }

        let temp_dir = TempDir::new()?;
        let config_path = temp_dir.path().join("config.json");

        let config = TestConfig {
            r: 8,
            alpha: 16,
            dropout: 0.1,
        };

        save_adapter_config(&config, &config_path)?;
        assert!(config_path.exists());

        let loaded_config: TestConfig = load_adapter_config(&config_path)?;
        assert_eq!(config, loaded_config);

        Ok(())
    }

    /// Saves weights plus config into a directory and loads both back.
    #[test]
    fn test_save_load_pretrained() -> anyhow::Result<()> {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
        struct TestConfig {
            r: usize,
            alpha: usize,
        }

        let device = Device::Cpu;
        let temp_dir = TempDir::new()?;
        let out_dir = temp_dir.path();

        let adapter = MockAdapter {
            weights: random_lora_weights(&device)?,
        };
        let config = TestConfig { r: 8, alpha: 16 };

        save_pretrained(&adapter, &config, out_dir)?;

        assert!(out_dir.join(ADAPTER_WEIGHTS_FILENAME).exists());
        assert!(out_dir.join(ADAPTER_CONFIG_FILENAME).exists());

        let mut loaded_adapter = MockAdapter {
            weights: HashMap::new(),
        };
        let loaded_config: TestConfig = load_pretrained(&mut loaded_adapter, out_dir, &device)?;

        assert_eq!(config, loaded_config);

        assert_eq!(loaded_adapter.weights.len(), 2);
        for name in ["lora_a", "lora_b"] {
            assert!(loaded_adapter.weights.contains_key(name));
        }

        Ok(())
    }
}